@@ -34,90 +34,92 @@ static void rope_yarn(
3434    *sin_theta = sycl::sin (theta) * mscale;
3535}
3636
37- template <typename  T, bool  has_ff>
38- static  void  rope_norm (
39-     const  T * x, T * dst, int  ne0, int  n_dims, const  int32_t  * pos, float  freq_scale, int  p_delta_rows,
40-     float  ext_factor, float  attn_factor, rope_corr_dims corr_dims, float  theta_scale, const  float  * freq_factors,
41-     const  sycl::nd_item<3 > &item_ct1) {
42-     const  int  i0 = 2  * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) +
43-                          item_ct1.get_local_id (1 ));
37+ template  <typename  T, bool  has_ff>
38+ static  void  rope_norm (const  T * x, T * dst, const  int  ne0, const  int  ne1, const  int  s1, const  int  s2, const  int  n_dims,
39+                       const  int32_t  * pos, float  freq_scale, float  ext_factor, float  attn_factor,
40+                       const  rope_corr_dims corr_dims, const  float  theta_scale, const  float  * freq_factors,
41+                       const  sycl::nd_item<3 > & item_ct1) {
42+     const  int  i0 = 2  * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) + item_ct1.get_local_id (1 ));
4443
4544    if  (i0 >= ne0) {
4645        return ;
4746    }
4847
49-     const  int  row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
50-                     item_ct1.get_local_id (2 );
48+     const  int  row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
5149
5250    if  (i0 >= n_dims) {
53-         const  int  i = row* ne0 + i0;
51+         const  int  i = row *  ne0 + i0;
5452
5553        dst[i + 0 ] = x[i + 0 ];
5654        dst[i + 1 ] = x[i + 1 ];
5755
5856        return ;
5957    }
6058
61-     const  int  i = row*ne0 + i0;
62-     const  int  i2 = row/p_delta_rows;
59+     const  int  row0     = row % ne1;
60+     const  int  channel0 = row / ne1;
61+ 
62+     const  int  i  = row * ne0 + i0;
63+     const  int  i2 = channel0 * s2 + row0 * s1 + i0;
6364
64-     const  float  theta_base = pos[i2 ] * sycl::pow (theta_scale, i0 / 2 .0f );
65+     const  float  theta_base = pos[channel0 ] * sycl::pow (theta_scale, i0 / 2 .0f );
6566
66-     const  float  freq_factor = has_ff ? freq_factors[i0/ 2 ] : 1 .0f ;
67+     const  float  freq_factor = has_ff ? freq_factors[i0 /  2 ] : 1 .0f ;
6768
6869    float  cos_theta;
6970    float  sin_theta;
7071
71-     rope_yarn (theta_base/ freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
72+     rope_yarn (theta_base /  freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
7273
73-     const  float  x0 = x[i  + 0 ];
74-     const  float  x1 = x[i  + 1 ];
74+     const  float  x0 = x[i2  + 0 ];
75+     const  float  x1 = x[i2  + 1 ];
7576
76-     dst[i + 0 ] = x0* cos_theta - x1* sin_theta;
77-     dst[i + 1 ] = x0* sin_theta + x1* cos_theta;
77+     dst[i + 0 ] = x0 *  cos_theta - x1 *  sin_theta;
78+     dst[i + 1 ] = x0 *  sin_theta + x1 *  cos_theta;
7879}
7980
80- template <typename  T, bool  has_ff>
81- static  void  rope_neox (
82-     const  T * x, T * dst, int  ne0, int  n_dims, const  int32_t  * pos, float  freq_scale, int  p_delta_rows,
83-     float  ext_factor, float  attn_factor, rope_corr_dims corr_dims, float  theta_scale, const  float  * freq_factors,
84-     const  sycl::nd_item<3 > &item_ct1) {
85-     const  int  i0 = 2  * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) +
86-                          item_ct1.get_local_id (1 ));
81+ template  <typename  T, bool  has_ff>
82+ static  void  rope_neox (const  T * x, T * dst, const  int  ne0, const  int  ne1, const  int  s1, const  int  s2, const  int  n_dims,
83+                       const  int32_t  * pos, const  float  freq_scale, const  float  ext_factor, const  float  attn_factor,
84+                       const  rope_corr_dims corr_dims, const  float  theta_scale, const  float  * freq_factors,
85+                       const  sycl::nd_item<3 > & item_ct1) {
86+     const  int  i0 = 2  * (item_ct1.get_local_range (1 ) * item_ct1.get_group (1 ) + item_ct1.get_local_id (1 ));
8787
8888    if  (i0 >= ne0) {
8989        return ;
9090    }
9191
92-     const  int  row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
93-                     item_ct1.get_local_id (2 );
92+     const  int  row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
9493
9594    if  (i0 >= n_dims) {
96-         const  int  i = row* ne0 + i0;
95+         const  int  i = row *  ne0 + i0;
9796
9897        dst[i + 0 ] = x[i + 0 ];
9998        dst[i + 1 ] = x[i + 1 ];
10099
101100        return ;
102101    }
103102
104-     const  int  i  = row*ne0 + i0/2 ;
105-     const  int  i2 = row/p_delta_rows;
103+     const  int  row0     = row % ne1;
104+     const  int  channel0 = row / ne1;
105+ 
106+     const  int  i  = row * ne0 + i0 / 2 ;
107+     const  int  i2 = channel0 * s2 + row0 * s1 + i0 / 2 ;
106108
107-     const  float  theta_base = pos[i2 ] * sycl::pow (theta_scale, i0 / 2 .0f );
109+     const  float  theta_base = pos[channel0 ] * sycl::pow (theta_scale, i0 / 2 .0f );
108110
109-     const  float  freq_factor = has_ff ? freq_factors[i0/ 2 ] : 1 .0f ;
111+     const  float  freq_factor = has_ff ? freq_factors[i0 /  2 ] : 1 .0f ;
110112
111113    float  cos_theta;
112114    float  sin_theta;
113115
114-     rope_yarn (theta_base/ freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
116+     rope_yarn (theta_base /  freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
115117
116-     const  float  x0 = x[i  + 0 ];
117-     const  float  x1 = x[i  + n_dims/ 2 ];
118+     const  float  x0 = x[i2  + 0 ];
119+     const  float  x1 = x[i2  + n_dims /  2 ];
118120
119-     dst[i + 0 ]        = x0* cos_theta - x1* sin_theta;
120-     dst[i + n_dims/ 2 ] = x0* sin_theta + x1* cos_theta;
121+     dst[i + 0 ]           = x0 *  cos_theta - x1 *  sin_theta;
122+     dst[i + n_dims /  2 ] = x0 *  sin_theta + x1 *  cos_theta;
121123}
122124
123125template  <typename  T, bool  has_ff>
@@ -163,80 +165,66 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons
163165}
164166
165167template  <typename  T>
166- static  void  rope_norm_sycl (
167-     const  T *x, T *dst, int  ne0, int  n_dims, int  nr, const  int32_t  *pos, float  freq_scale, int  p_delta_rows,
168-     float  freq_base, float  ext_factor, float  attn_factor, rope_corr_dims corr_dims, const  float  * freq_factors, queue_ptr stream) {
168+ static  void  rope_norm_sycl (const  T * x, T * dst, const  int  ne0, const  int  ne1, const  int  s1, const  int  s2,
169+                            const  int  n_dims, int  nr, const  int32_t  * pos, const  float  freq_scale, const  float  freq_base,
170+                            const  float  ext_factor, const  float  attn_factor, const  rope_corr_dims corr_dims,
171+                            const  float  * freq_factors, queue_ptr stream) {
169172    GGML_ASSERT (ne0 % 2  == 0 );
170173    const  sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
171-     const  int  num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
174+     const  int              num_blocks_x = (ne0 + 2  *  SYCL_ROPE_BLOCK_SIZE - 1 ) / (2  *  SYCL_ROPE_BLOCK_SIZE);
172175    const  sycl::range<3 > block_nums (1 , num_blocks_x, nr);
173176
174-     const  float  theta_scale = powf (freq_base, -2 .0f / n_dims);
177+     const  float  theta_scale = powf (freq_base, -2 .0f  /  n_dims);
175178
176-     dpct::has_capability_or_fail (stream->get_device (),
177-                                      {sycl::aspect::fp16});
179+     dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
178180
179181    if  (freq_factors == nullptr ) {
180182        /* 
181183        DPCT1049:40: The work-group size passed to the SYCL kernel may exceed 
182184        the limit. To get the device limit, query 
183185        info::device::max_work_group_size. Adjust the work-group size if needed. 
184186        */  
185-         stream->parallel_for (
186-             sycl::nd_range<3 >(block_nums * block_dims, block_dims),
187-             [=](sycl::nd_item<3 > item_ct1) {
188-                 rope_norm<T, false >(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
189-                                ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
190-                                item_ct1);
191-             });
187+         stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
188+             rope_norm<T, false >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
189+                                 theta_scale, freq_factors, item_ct1);
190+         });
192191    } else  {
193192        /* 
194193        DPCT1049:41: The work-group size passed to the SYCL kernel may exceed 
195194        the limit. To get the device limit, query 
196195        info::device::max_work_group_size. Adjust the work-group size if needed. 
197196        */  
198-         stream->parallel_for (
199-             sycl::nd_range<3 >(block_nums * block_dims, block_dims),
200-             [=](sycl::nd_item<3 > item_ct1) {
201-                 rope_norm<T, true >(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
202-                               ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
203-                               item_ct1);
204-             });
197+         stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
198+             rope_norm<T, true >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
199+                                theta_scale, freq_factors, item_ct1);
200+         });
205201    }
206202}
207203
208204template  <typename  T>
209- static  void  rope_neox_sycl (
210-     const  T *x, T *dst, int  ne0, int  n_dims, int  nr, const  int32_t  *pos, float  freq_scale, int  p_delta_rows,
211-     float  freq_base, float  ext_factor, float  attn_factor, rope_corr_dims corr_dims, const  float  * freq_factors, queue_ptr stream) {
205+ static  void  rope_neox_sycl (const  T * x, T * dst, const  int  ne0, const  int  ne1, const  int  s1, const  int  s2,
206+                            const  int  n_dims, const  int  nr, const  int32_t  * pos, const  float  freq_scale,
207+                            const  float  freq_base, const  float  ext_factor, const  float  attn_factor,
208+                            const  rope_corr_dims corr_dims, const  float  * freq_factors, queue_ptr stream) {
212209    GGML_ASSERT (ne0 % 2  == 0 );
213210    const  sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
214-     const  int  num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
211+     const  int              num_blocks_x = (ne0 + 2  *  SYCL_ROPE_BLOCK_SIZE - 1 ) / (2  *  SYCL_ROPE_BLOCK_SIZE);
215212    const  sycl::range<3 > block_nums (1 , num_blocks_x, nr);
216213
217-     const  float  theta_scale = powf (freq_base, -2 .0f / n_dims);
214+     const  float  theta_scale = powf (freq_base, -2 .0f  /  n_dims);
218215
219-     dpct::has_capability_or_fail (stream->get_device (),
220-                                     {sycl::aspect::fp16});
216+     dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
221217
222218    if  (freq_factors == nullptr ) {
223-         stream->parallel_for (
224-             sycl::nd_range<3 >(block_nums * block_dims, block_dims),
225-             [=](sycl::nd_item<3 > item_ct1) {
226-                 rope_neox<T, false >(x, dst, ne0, n_dims, pos, freq_scale,
227-                                     p_delta_rows, ext_factor, attn_factor,
228-                                     corr_dims, theta_scale, freq_factors,
229-                                     item_ct1);
230-             });
219+         stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
220+             rope_neox<T, false >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
221+                                 theta_scale, freq_factors, item_ct1);
222+         });
231223    } else  {
232-         stream->parallel_for (
233-             sycl::nd_range<3 >(block_nums * block_dims, block_dims),
234-             [=](sycl::nd_item<3 > item_ct1) {
235-                 rope_neox<T, true >(x, dst, ne0, n_dims, pos, freq_scale,
236-                                     p_delta_rows, ext_factor, attn_factor,
237-                                     corr_dims, theta_scale, freq_factors,
238-                                     item_ct1);
239-             });
224+         stream->parallel_for (sycl::nd_range<3 >(block_nums * block_dims, block_dims), [=](sycl::nd_item<3 > item_ct1) {
225+             rope_neox<T, true >(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
226+                                theta_scale, freq_factors, item_ct1);
227+         });
240228    }
241229}
242230
@@ -272,7 +260,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
272260    }
273261}
274262
275- void  ggml_sycl_op_rope (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
263+ inline   void  ggml_sycl_op_rope (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
276264
277265    GGML_ASSERT (dst->src [0 ]->type  == GGML_TYPE_F32 || dst->src [0 ]->type  == GGML_TYPE_F16);
278266    GGML_ASSERT ( dst->type  == GGML_TYPE_F32 ||  dst->type  == GGML_TYPE_F16);
@@ -329,43 +317,46 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
329317    if  (is_neox) {
330318        GGML_SYCL_DEBUG (" %s: neox path\n "  , __func__);
331319        if  (dst->src [0 ]->type  == GGML_TYPE_F32) {
332-             rope_neox_sycl (
333-                 (const  float  *)dst->src [0 ]->data , (float  *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
334-                 attn_factor, corr_dims, freq_factors, main_stream
335-             );
320+             rope_neox_sycl ((const  float  *) dst->src [0 ]->data , (float  *) dst->data , ne00, ne01, s01, s02, n_dims, nr,
321+                            pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
336322        } else  if  (dst->src [0 ]->type  == GGML_TYPE_F16) {
337-             rope_neox_sycl (
338-                 (const  sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
339-                 attn_factor, corr_dims, freq_factors, main_stream
340-             );
323+             rope_neox_sycl ((const  sycl::half *) dst->src [0 ]->data , (sycl::half *) dst->data , ne00, ne01, s01, s02,
324+                            n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
325+                            main_stream);
341326        } else  {
342327            GGML_ABORT (" fatal error"  );
343328        }
344329    } else  if  (is_vision) {
345330        GGML_SYCL_DEBUG (" %s: vision path\n "  , __func__);
346331        if  (dst->src [0 ]->type  == GGML_TYPE_F16) {
347-             rope_vision_sycl ((const  sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
348-                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
332+             rope_vision_sycl ((const  sycl::half *) dst->src [0 ]->data , (sycl::half *) dst->data , ne00, ne01, ne02, s01,
333+                              s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
334+                              freq_factors, sections, main_stream);
349335        } else  if  (dst->src [0 ]->type  == GGML_TYPE_F32) {
350-             rope_vision_sycl ((const  float  *) dst->src [0 ]->data , (float  *)dst->data , ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
351-                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
336+             rope_vision_sycl ((const  float  *) dst->src [0 ]->data , (float  *) dst->data , ne00, ne01, ne02, s01, s02, n_dims,
337+                              nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
338+                              main_stream);
352339        } else  {
353340            GGML_ABORT (" Fatal error: Tensor type unsupported!"  );
354341        }
355342    } else  {
356343        GGML_SYCL_DEBUG (" %s: norm path\n "  , __func__);
357344        if  (dst->src [0 ]->type  == GGML_TYPE_F32) {
358-             rope_norm_sycl (
359-                 (const  float  *)dst->src [0 ]->data , (float  *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
360-                 attn_factor, corr_dims, freq_factors, main_stream
361-             );
345+             rope_norm_sycl ((const  float  *) dst->src [0 ]->data , (float  *) dst->data , ne00, ne01, s01, s02, n_dims, nr,
346+                            pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
362347        } else  if  (dst->src [0 ]->type  == GGML_TYPE_F16) {
363-             rope_norm_sycl (
364-                 (const  sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
365-                 attn_factor, corr_dims, freq_factors, main_stream
366-             );
348+             rope_norm_sycl ((const  sycl::half *) dst->src [0 ]->data , (sycl::half *) dst->data , ne00, ne01, s01, s02,
349+                            n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
350+                            main_stream);
367351        } else  {
368352            GGML_ABORT (" fatal error"  );
369353        }
370354    }
371355}
356+ 
357+ void  ggml_sycl_rope (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
358+     GGML_SYCL_DEBUG (" call %s\n "  , __func__);
359+     ggml_sycl_op_rope (ctx, dst);
360+     GGML_SYCL_DEBUG (" call %s done\n "  , __func__);
361+ }
362+ 
0 commit comments