@@ -34,90 +34,92 @@ static void rope_yarn(
3434 *sin_theta = sycl::sin (theta) * mscale;
3535}
3636
37- template <typename T, bool has_ff>
38- static void rope_norm (
39- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
40- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
41- const sycl::nd_item<3 > &item_ct1) {
42- const int i0 = 2 * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) +
43- item_ct1.get_local_id (1 ));
37+ template <typename T, bool has_ff>
38+ static void rope_norm (const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
39+ const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
40+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
41+ const sycl::nd_item<3 > & item_ct1) {
42+ const int i0 = 2 * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) + item_ct1.get_local_id (1 ));
4443
4544 if (i0 >= ne0) {
4645 return ;
4746 }
4847
49- const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
50- item_ct1.get_local_id (2 );
48+ const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
5149
5250 if (i0 >= n_dims) {
53- const int i = row* ne0 + i0;
51+ const int i = row * ne0 + i0;
5452
5553 dst[i + 0 ] = x[i + 0 ];
5654 dst[i + 1 ] = x[i + 1 ];
5755
5856 return ;
5957 }
6058
61- const int i = row*ne0 + i0;
62- const int i2 = row/p_delta_rows;
59+ const int row0 = row % ne1;
60+ const int channel0 = row / ne1;
61+
62+ const int i = row * ne0 + i0;
63+ const int i2 = channel0 * s2 + row0 * s1 + i0;
6364
64- const float theta_base = pos[i2 ] * sycl::pow (theta_scale, i0 / 2 .0f );
65+ const float theta_base = pos[channel0 ] * sycl::pow (theta_scale, i0 / 2 .0f );
6566
66- const float freq_factor = has_ff ? freq_factors[i0/ 2 ] : 1 .0f ;
67+ const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
6768
6869 float cos_theta;
6970 float sin_theta;
7071
71- rope_yarn (theta_base/ freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
72+ rope_yarn (theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
7273
73- const float x0 = x[i + 0 ];
74- const float x1 = x[i + 1 ];
74+ const float x0 = x[i2 + 0 ];
75+ const float x1 = x[i2 + 1 ];
7576
76- dst[i + 0 ] = x0* cos_theta - x1* sin_theta;
77- dst[i + 1 ] = x0* sin_theta + x1* cos_theta;
77+ dst[i + 0 ] = x0 * cos_theta - x1 * sin_theta;
78+ dst[i + 1 ] = x0 * sin_theta + x1 * cos_theta;
7879}
7980
80- template <typename T, bool has_ff>
81- static void rope_neox (
82- const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
83- float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
84- const sycl::nd_item<3 > &item_ct1) {
85- const int i0 = 2 * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) +
86- item_ct1.get_local_id (1 ));
81+ template <typename T, bool has_ff>
82+ static void rope_neox (const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
83+ const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
84+ const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
85+ const sycl::nd_item<3 > & item_ct1) {
86+ const int i0 = 2 * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) + item_ct1.get_local_id (1 ));
8787
8888 if (i0 >= ne0) {
8989 return ;
9090 }
9191
92- const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
93- item_ct1.get_local_id (2 );
92+ const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
9493
9594 if (i0 >= n_dims) {
96- const int i = row* ne0 + i0;
95+ const int i = row * ne0 + i0;
9796
9897 dst[i + 0 ] = x[i + 0 ];
9998 dst[i + 1 ] = x[i + 1 ];
10099
101100 return ;
102101 }
103102
104- const int i = row*ne0 + i0/2 ;
105- const int i2 = row/p_delta_rows;
103+ const int row0 = row % ne1;
104+ const int channel0 = row / ne1;
105+
106+ const int i = row * ne0 + i0 / 2 ;
107+ const int i2 = channel0 * s2 + row0 * s1 + i0 / 2 ;
106108
107- const float theta_base = pos[i2 ] * sycl::pow (theta_scale, i0 / 2 .0f );
109+ const float theta_base = pos[channel0 ] * sycl::pow (theta_scale, i0 / 2 .0f );
108110
109- const float freq_factor = has_ff ? freq_factors[i0/ 2 ] : 1 .0f ;
111+ const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
110112
111113 float cos_theta;
112114 float sin_theta;
113115
114- rope_yarn (theta_base/ freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
116+ rope_yarn (theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
115117
116- const float x0 = x[i + 0 ];
117- const float x1 = x[i + n_dims/ 2 ];
118+ const float x0 = x[i2 + 0 ];
119+ const float x1 = x[i2 + n_dims / 2 ];
118120
119- dst[i + 0 ] = x0* cos_theta - x1* sin_theta;
120- dst[i + n_dims/ 2 ] = x0* sin_theta + x1* cos_theta;
121+ dst[i + 0 ] = x0 * cos_theta - x1 * sin_theta;
122+ dst[i + n_dims / 2 ] = x0 * sin_theta + x1 * cos_theta;
121123}
122124
123125template <typename T, bool has_ff>
@@ -163,80 +165,66 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons
163165}
164166
165167template <typename T>
166- static void rope_norm_sycl (
167- const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
168- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
168+ static void rope_norm_sycl (const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
169+ const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
170+ const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
171+ const float * freq_factors, queue_ptr stream) {
169172 GGML_ASSERT (ne0 % 2 == 0 );
170173 const sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
171- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
174+ const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
172175 const sycl::range<3 > block_nums (1 , num_blocks_x, nr);
173176
174- const float theta_scale = powf (freq_base, -2 .0f / n_dims);
177+ const float theta_scale = powf (freq_base, -2 .0f / n_dims);
175178
176- dpct::has_capability_or_fail (stream->get_device (),
177- {sycl::aspect::fp16});
179+ dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
178180
179181 if (freq_factors == nullptr ) {
180182 /*
181183 DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
182184 the limit. To get the device limit, query
183185 info::device::max_work_group_size. Adjust the work-group size if needed.
184186 */
185- stream->parallel_for (
186- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
187- [=](sycl::nd_item<3 > item_ct1) {
188- rope_norm<T, false >(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
189- ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
190- item_ct1);
191- });
187+ stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
188+ rope_norm<T, false >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
189+ theta_scale, freq_factors, item_ct1);
190+ });
192191 } else {
193192 /*
194193 DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
195194 the limit. To get the device limit, query
196195 info::device::max_work_group_size. Adjust the work-group size if needed.
197196 */
198- stream->parallel_for (
199- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
200- [=](sycl::nd_item<3 > item_ct1) {
201- rope_norm<T, true >(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
202- ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
203- item_ct1);
204- });
197+ stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
198+ rope_norm<T, true >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
199+ theta_scale, freq_factors, item_ct1);
200+ });
205201 }
206202}
207203
208204template <typename T>
209- static void rope_neox_sycl (
210- const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
211- float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
205+ static void rope_neox_sycl (const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
206+ const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
207+ const float freq_base, const float ext_factor, const float attn_factor,
208+ const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
212209 GGML_ASSERT (ne0 % 2 == 0 );
213210 const sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
214- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
211+ const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
215212 const sycl::range<3 > block_nums (1 , num_blocks_x, nr);
216213
217- const float theta_scale = powf (freq_base, -2 .0f / n_dims);
214+ const float theta_scale = powf (freq_base, -2 .0f / n_dims);
218215
219- dpct::has_capability_or_fail (stream->get_device (),
220- {sycl::aspect::fp16});
216+ dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
221217
222218 if (freq_factors == nullptr ) {
223- stream->parallel_for (
224- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
225- [=](sycl::nd_item<3 > item_ct1) {
226- rope_neox<T, false >(x, dst, ne0, n_dims, pos, freq_scale,
227- p_delta_rows, ext_factor, attn_factor,
228- corr_dims, theta_scale, freq_factors,
229- item_ct1);
230- });
219+ stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
220+ rope_neox<T, false >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
221+ theta_scale, freq_factors, item_ct1);
222+ });
231223 } else {
232- stream->parallel_for (
233- sycl::nd_range<3 >(block_nums * block_dims, block_dims),
234- [=](sycl::nd_item<3 > item_ct1) {
235- rope_neox<T, true >(x, dst, ne0, n_dims, pos, freq_scale,
236- p_delta_rows, ext_factor, attn_factor,
237- corr_dims, theta_scale, freq_factors,
238- item_ct1);
239- });
224+ stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
225+ rope_neox<T, true >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
226+ theta_scale, freq_factors, item_ct1);
227+ });
240228 }
241229}
242230
@@ -272,7 +260,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
272260 }
273261}
274262
275- void ggml_sycl_op_rope (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
263+ inline void ggml_sycl_op_rope (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
276264
277265 GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32 || dst->src [0 ]->type == GGML_TYPE_F16);
278266 GGML_ASSERT ( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -329,43 +317,46 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
329317 if (is_neox) {
330318 GGML_SYCL_DEBUG (" %s: neox path\n " , __func__);
331319 if (dst->src [0 ]->type == GGML_TYPE_F32) {
332- rope_neox_sycl (
333- (const float *)dst->src [0 ]->data , (float *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
334- attn_factor, corr_dims, freq_factors, main_stream
335- );
320+ rope_neox_sycl ((const float *) dst->src [0 ]->data , (float *) dst->data , ne00, ne01, s01, s02, n_dims, nr,
321+ pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
336322 } else if (dst->src [0 ]->type == GGML_TYPE_F16) {
337- rope_neox_sycl (
338- (const sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
339- attn_factor, corr_dims, freq_factors, main_stream
340- );
323+ rope_neox_sycl ((const sycl::half *) dst->src [0 ]->data , (sycl::half *) dst->data , ne00, ne01, s01, s02,
324+ n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
325+ main_stream);
341326 } else {
342327 GGML_ABORT (" fatal error" );
343328 }
344329 } else if (is_vision) {
345330 GGML_SYCL_DEBUG (" %s: vision path\n " , __func__);
346331 if (dst->src [0 ]->type == GGML_TYPE_F16) {
347- rope_vision_sycl ((const sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
348- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
332+ rope_vision_sycl ((const sycl::half *) dst->src [0 ]->data , (sycl::half *) dst->data , ne00, ne01, ne02, s01,
333+ s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
334+ freq_factors, sections, main_stream);
349335 } else if (dst->src [0 ]->type == GGML_TYPE_F32) {
350- rope_vision_sycl ((const float *) dst->src [0 ]->data , (float *)dst->data , ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
351- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
336+ rope_vision_sycl ((const float *) dst->src [0 ]->data , (float *) dst->data , ne00, ne01, ne02, s01, s02, n_dims,
337+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
338+ main_stream);
352339 } else {
353340 GGML_ABORT (" Fatal error: Tensor type unsupported!" );
354341 }
355342 } else {
356343 GGML_SYCL_DEBUG (" %s: norm path\n " , __func__);
357344 if (dst->src [0 ]->type == GGML_TYPE_F32) {
358- rope_norm_sycl (
359- (const float *)dst->src [0 ]->data , (float *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
360- attn_factor, corr_dims, freq_factors, main_stream
361- );
345+ rope_norm_sycl ((const float *) dst->src [0 ]->data , (float *) dst->data , ne00, ne01, s01, s02, n_dims, nr,
346+ pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
362347 } else if (dst->src [0 ]->type == GGML_TYPE_F16) {
363- rope_norm_sycl (
364- (const sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
365- attn_factor, corr_dims, freq_factors, main_stream
366- );
348+ rope_norm_sycl ((const sycl::half *) dst->src [0 ]->data , (sycl::half *) dst->data , ne00, ne01, s01, s02,
349+ n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
350+ main_stream);
367351 } else {
368352 GGML_ABORT (" fatal error" );
369353 }
370354 }
371355}
356+
357+ void ggml_sycl_rope (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
358+ GGML_SYCL_DEBUG (" call %s\n " , __func__);
359+ ggml_sycl_op_rope (ctx, dst);
360+ GGML_SYCL_DEBUG (" call %s done\n " , __func__);
361+ }
362+
0 commit comments