11#include  " rope.hpp" 
2+ #include  " ggml-sycl/common.hpp" 
3+ #include  " ggml.h" 
24
35struct  rope_corr_dims  {
46    float  v[2 ];
57};
68
9+ struct  mrope_sections  {
10+     int  v[4 ];
11+ };
12+ 
713static  float  rope_yarn_ramp (const  float  low, const  float  high, const  int  i0) {
814    const  float  y = (i0 / 2  - low) / sycl::max (0 .001f , high - low);
915    return  1 .0f  - sycl::min (1 .0f , sycl::max (0 .0f , y));
@@ -114,6 +120,48 @@ static void rope_neox(
114120    dst[i + n_dims/2 ] = x0*sin_theta + x1*cos_theta;
115121}
116122
123+ template  <typename  T, bool  has_ff>
124+ static  void  rope_vision (const  T * x, T * dst, const  int  ne0, const  int  ne1, const  int  ne2, const  size_t  s1,
125+                         const  size_t  s2, const  int  n_dims, const  int32_t  * pos, const  float  freq_scale,
126+                         const  float  ext_factor, const  float  attn_factor, const  rope_corr_dims corr_dims,
127+                         const  float  theta_scale, const  float  * freq_factors, const  mrope_sections sections,
128+                         const  sycl::nd_item<3 > & item_ct1) {
129+     //  get index pos
130+     const  int  i0 = 2  * (item_ct1.get_group (1 ) * item_ct1.get_local_range (1 ) + item_ct1.get_local_id (1 ));
131+     if  (i0 >= ne0) {
132+         return ;
133+     }
134+     const  int     row_dst   = (item_ct1.get_group (2 ) * item_ct1.get_local_range (2 )) + item_ct1.get_local_id (2 );
135+     const  int     row_x     = row_dst % ne1;
136+     const  int     channel_x = row_dst / ne1;
137+     const  int     idst      = (row_dst * ne0) + (i0 / 2 );
138+     const  size_t  ix        = ((size_t ) channel_x * s2) + ((size_t ) row_x * s1) + (i0 / 2 );
139+ 
140+     const  int  sect_dims = sections.v [0 ] + sections.v [1 ];
141+     const  int  sector    = (i0 / 2 ) % sect_dims;
142+ 
143+     float  theta_base = 0 .0f ;
144+     if  (sector < sections.v [0 ]) {
145+         const  int  p = sector;
146+         theta_base  = pos[channel_x] * sycl::pow (theta_scale, (float ) p);
147+     } else  {
148+         //  Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0]
149+         const  int  p = sector - sections.v [0 ];
150+         theta_base  = pos[channel_x + ne2] * sycl::pow (theta_scale, (float ) p);
151+     }
152+ 
153+     const  float  freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
154+     float        cos_theta;
155+     float        sin_theta;
156+     rope_yarn (theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
157+     const  float  x0 = x[ix + 0 ];
158+     const  float  x1 = x[ix + n_dims];
159+ 
160+     //  store results in dst
161+     dst[idst + 0 ]      = x0 * cos_theta - x1 * sin_theta;
162+     dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
163+ }
164+ 
117165template  <typename  T>
118166static  void  rope_norm_sycl (
119167    const  T *x, T *dst, int  ne0, int  n_dims, int  nr, const  int32_t  *pos, float  freq_scale, int  p_delta_rows,
@@ -192,21 +240,58 @@ static void rope_neox_sycl(
192240    }
193241}
194242
243+ //  rope vision
244+ template  <typename  T>
245+ static  void  rope_vision_sycl (const  T * x, T * dst, const  int  ne0, const  int  ne1, const  int  ne2, const  size_t  s1,
246+                              const  size_t  s2, const  int  n_dims, const  int  nr, const  int32_t  * pos,
247+                              const  float  freq_scale, const  float  freq_base, const  float  ext_factor,
248+                              const  float  attn_factor, const  rope_corr_dims corr_dims, const  float  * freq_factors,
249+                              const  mrope_sections sections, queue_ptr stream) {
250+     GGML_ASSERT (ne0 % 2  == 0 );
251+     const  sycl::range<3 >    block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
252+     const  int                n_blocks_y = (ne0 + 2  * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2  * SYCL_ROPE_BLOCK_SIZE);
253+     const  sycl::range<3 >    grid_dims (1 , n_blocks_y, nr);
254+     const  sycl::nd_range<3 > nd_range (grid_dims * block_dims, block_dims);
255+ 
256+     const  float  theta_scale = std::pow (freq_base, -2 .0f  / n_dims);
257+     //  Add FP16 capability check if T could be sycl::half
258+     if  constexpr  (std::is_same_v<T, sycl::half>) {
259+         dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
260+     }
261+     //  launch kernel
262+     if  (freq_factors == nullptr ) {
263+         stream->parallel_for (nd_range, [=](sycl::nd_item<3 > item_ct1) {
264+             rope_vision<T, false >(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
265+                                   corr_dims, theta_scale, freq_factors, sections, item_ct1);
266+         });
267+     } else  {
268+         stream->parallel_for (nd_range, [=](sycl::nd_item<3 > item_ct1) {
269+             rope_vision<T, true >(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
270+                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
271+         });
272+     }
273+ }
274+ 
195275void  ggml_sycl_op_rope (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
196276
197277    GGML_ASSERT (dst->src [0 ]->type  == GGML_TYPE_F32 || dst->src [0 ]->type  == GGML_TYPE_F16);
198278    GGML_ASSERT ( dst->type  == GGML_TYPE_F32 ||  dst->type  == GGML_TYPE_F16);
199279    GGML_ASSERT (dst->src [0 ]->type  == dst->type );
200- 
201-     const  int64_t  ne00  = dst->src [0 ]->ne [0 ]; 
202-     const  int64_t  ne01  = dst->src [0 ]->ne [1 ]; 
280+      const   int64_t  ne00 = dst-> src [ 0 ]-> ne [ 0 ];  //  head dims 
281+     const  int64_t  ne01  = dst->src [0 ]->ne [1 ];  //  num heads 
282+     const  int64_t  ne02  = dst->src [0 ]->ne [2 ];  //  num heads 
203283    const  int64_t  nr = ggml_nrows (dst->src [0 ]);
204284
285+     const  size_t  s01 = dst->src [0 ]->nb [1 ] / ggml_type_size (dst->src [0 ]->type );
286+     const  size_t  s02 = dst->src [0 ]->nb [2 ] / ggml_type_size (dst->src [0 ]->type );
287+ 
288+ 
205289    // const int n_past      = ((int32_t *) dst->op_params)[0];
206290    const  int  n_dims      = ((int32_t  *) dst->op_params )[1 ];
207291    const  int  mode        = ((int32_t  *) dst->op_params )[2 ];
208292    // const int n_ctx       = ((int32_t *) dst->op_params)[3];
209293    const  int  n_ctx_orig  = ((int32_t  *) dst->op_params )[4 ];
294+     mrope_sections sections;
210295
211296    //  RoPE alteration for extended context
212297    float  freq_base;
@@ -222,8 +307,10 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
222307    memcpy (&attn_factor, (int32_t  *) dst->op_params  +  8 , sizeof (float ));
223308    memcpy (&beta_fast,   (int32_t  *) dst->op_params  +  9 , sizeof (float ));
224309    memcpy (&beta_slow,   (int32_t  *) dst->op_params  + 10 , sizeof (float ));
310+     memcpy (§ions.v ,  (int32_t  *) dst->op_params  + 11 , sizeof (int )*4 );
225311
226312    const  bool  is_neox = mode & GGML_ROPE_TYPE_NEOX;
313+     const  bool  is_vision = mode == GGML_ROPE_TYPE_VISION;
227314
228315    const  int32_t  * pos = (const  int32_t  *) dst->src [1 ]->data ;
229316
@@ -240,6 +327,7 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
240327
241328    //  compute
242329    if  (is_neox) {
330+         GGML_SYCL_DEBUG (" %s: neox path\n "  , __func__);
243331        if  (dst->src [0 ]->type  == GGML_TYPE_F32) {
244332            rope_neox_sycl (
245333                (const  float  *)dst->src [0 ]->data , (float  *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
@@ -253,7 +341,19 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
253341        } else  {
254342            GGML_ABORT (" fatal error"  );
255343        }
344+     } else  if  (is_vision) {
345+         GGML_SYCL_DEBUG (" %s: vision path\n "  , __func__);
346+         if  (dst->src [0 ]->type  == GGML_TYPE_F16) {
347+             rope_vision_sycl ((const  sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
348+                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
349+         } else  if  (dst->src [0 ]->type  == GGML_TYPE_F32) {
350+             rope_vision_sycl ((const  float  *) dst->src [0 ]->data , (float  *)dst->data , ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
351+                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
352+         } else  {
353+             GGML_ABORT (" Fatal error: Tensor type unsupported!"  );
354+         }
256355    } else  {
356+         GGML_SYCL_DEBUG (" %s: norm path\n "  , __func__);
257357        if  (dst->src [0 ]->type  == GGML_TYPE_F32) {
258358            rope_norm_sycl (
259359                (const  float  *)dst->src [0 ]->data , (float  *)dst->data , ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
0 commit comments