@@ -37,11 +37,23 @@ static __device__ void rope_yarn(
     }
 }
 
-template<bool forward, bool has_ff, typename T>
-static __global__ void rope_norm(
-        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_norm(const T *            x,
+                                 D *                  dst,
+                                 const int            ne0,
+                                 const int            ne1,
+                                 const int            s1,
+                                 const int            s2,
+                                 const int            n_dims,
+                                 const int32_t *      pos,
+                                 const float          freq_scale,
+                                 const float          ext_factor,
+                                 const float          attn_factor,
+                                 const rope_corr_dims corr_dims,
+                                 const float          theta_scale,
+                                 const float *        freq_factors,
+                                 const int64_t *      row_indices,
+                                 const int            set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -53,12 +65,19 @@ static __global__ void rope_norm(
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
-    const int idst = row_dst*ne0 + i0;
+    int       idst = row_dst * ne0 + i0;
     const int ix   = channel_x*s2 + row_x*s1 + i0;
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+    if (set_rows_stride != 0) {
+        idst  = row_x * ne0 + i0;
+        idst += row_indices[channel_x] * set_rows_stride;
+    }
+
     if (i0 >= n_dims) {
-        dst[idst + 0] = x[ix + 0];
-        dst[idst + 1] = x[ix + 1];
+        dst[idst + 0] = D(x[ix + 0]);
+        dst[idst + 1] = D(x[ix + 1]);
 
         return;
     }
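
The fused branch above replaces the usual contiguous destination index with a scatter: row `row_x` of channel `channel_x` lands at the row given by `row_indices[channel_x]` in the final tensor. A minimal standalone sketch of that index math (plain C++; `ne0`, `ne1`, `row_indices`, and `set_rows_stride` are invented values, not part of the patch):

```cpp
// Standalone sketch of the indexing above; all values are invented
// for illustration.
#include <cstdint>
#include <cstdio>

int main() {
    const int     ne0             = 8;         // elements per row
    const int     ne1             = 4;         // rows per channel
    const int64_t row_indices[2]  = { 10, 3 }; // hypothetical SET_ROWS targets
    const int     set_rows_stride = 64;        // row stride of the final tensor

    for (int channel_x = 0; channel_x < 2; ++channel_x) {
        for (int row_x = 0; row_x < ne1; ++row_x) {
            const int row_dst = channel_x*ne1 + row_x;
            // plain path: contiguous rope output
            const int idst_plain = row_dst*ne0;
            // fused path: scatter the row to where SET_ROWS would copy it
            const int idst_fused = row_x*ne0 + (int)(row_indices[channel_x]*set_rows_stride);
            printf("row_dst=%2d  plain=%3d  fused=%3d\n", row_dst, idst_plain, idst_fused);
        }
    }
    return 0;
}
```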
@@ -75,15 +94,27 @@ static __global__ void rope_norm(
     const float x0 = x[ix + 0];
     const float x1 = x[ix + 1];
 
-    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
-    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta);
+    dst[idst + 1] = D(x0 * sin_theta + x1 * cos_theta);
 }
 
-template<bool forward, bool has_ff, typename T>
-static __global__ void rope_neox(
-        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_neox(const T *            x,
+                                 D *                  dst,
+                                 const int            ne0,
+                                 const int            ne1,
+                                 const int            s1,
+                                 const int            s2,
+                                 const int            n_dims,
+                                 const int32_t *      pos,
+                                 const float          freq_scale,
+                                 const float          ext_factor,
+                                 const float          attn_factor,
+                                 const rope_corr_dims corr_dims,
+                                 const float          theta_scale,
+                                 const float *        freq_factors,
+                                 const int64_t *      row_indices,
+                                 const int            set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -95,12 +126,19 @@ static __global__ void rope_neox(
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
-    const int idst = row_dst*ne0 + i0/2;
+    int       idst = row_dst * ne0 + i0 / 2;
     const int ix   = channel_x*s2 + row_x*s1 + i0/2;
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+    if (set_rows_stride != 0) {
+        idst  = row_x * ne0 + i0 / 2;
+        idst += row_indices[channel_x] * set_rows_stride;
+    }
+
     if (i0 >= n_dims) {
-        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
-        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+        dst[idst + i0/2 + 0] = D(x[ix + i0/2 + 0]);
+        dst[idst + i0/2 + 1] = D(x[ix + i0/2 + 1]);
 
         return;
     }
@@ -117,8 +155,8 @@ static __global__ void rope_neox(
     const float x0 = x[ix + 0];
     const float x1 = x[ix + n_dims/2];
 
-    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]          = D(x0 * cos_theta - x1 * sin_theta);
+    dst[idst + n_dims / 2] = D(x0 * sin_theta + x1 * cos_theta);
 }
 
 template<bool forward, bool has_ff, typename T>
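
`rope_norm` and `rope_neox` apply the same rotation but pair different elements: `rope_norm` rotates adjacent elements `(x[i0], x[i0+1])`, while `rope_neox` rotates elements half a head apart, `(x[i0/2], x[i0/2 + n_dims/2])`. A small sketch that prints both pairings (`n_dims` invented for illustration):

```cpp
// Which input elements each kernel rotates together, per the index math
// above; n_dims is invented for illustration.
#include <cstdio>

int main() {
    const int n_dims = 8;
    // Each kernel thread owns an even i0 and rotates one pair of elements.
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        printf("i0=%d  norm: (x[%d], x[%d])  neox: (x[%d], x[%d])\n",
               i0,
               i0, i0 + 1,              // rope_norm: adjacent pair
               i0/2, i0/2 + n_dims/2);  // rope_neox: halves paired across n_dims/2
    }
    return 0;
}
```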
@@ -226,11 +264,25 @@ static __global__ void rope_vision(
     dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
 }
 
-template<bool forward, typename T>
-static void rope_norm_cuda(
-        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+template <bool forward, typename T, typename D>
+static void rope_norm_cuda(const T *            x,
+                           D *                  dst,
+                           const int            ne0,
+                           const int            ne1,
+                           const int            s1,
+                           const int            s2,
+                           const int            n_dims,
+                           const int            nr,
+                           const int32_t *      pos,
+                           const float          freq_scale,
+                           const float          freq_base,
+                           const float          ext_factor,
+                           const float          attn_factor,
+                           const rope_corr_dims corr_dims,
+                           const float *        freq_factors,
+                           const int64_t *      row_indices,
+                           const int            set_rows_stride,
+                           cudaStream_t         stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -240,20 +292,34 @@ static void rope_norm_cuda(
 
     if (freq_factors == nullptr) {
         rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     } else {
         rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     }
 }
 
-template<bool forward, typename T>
-static void rope_neox_cuda(
-        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+template <bool forward, typename T, typename D>
+static void rope_neox_cuda(const T *            x,
+                           D *                  dst,
+                           const int            ne0,
+                           const int            ne1,
+                           const int            s1,
+                           const int            s2,
+                           const int            n_dims,
+                           const int            nr,
+                           const int32_t *      pos,
+                           const float          freq_scale,
+                           const float          freq_base,
+                           const float          ext_factor,
+                           const float          attn_factor,
+                           const rope_corr_dims corr_dims,
+                           const float *        freq_factors,
+                           const int64_t *      row_indices,
+                           const int            set_rows_stride,
+                           cudaStream_t         stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -262,13 +328,13 @@ static void rope_neox_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+        rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     } else {
-        rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+        rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     }
 }
 
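Both host wrappers use the same launch geometry: one thread rotates one pair of elements, so a row of `ne0` elements needs `ne0/2` threads along y, and the grid's x dimension spans the `nr` rows. A sketch of that arithmetic (values invented; `CUDA_ROPE_BLOCK_SIZE` assumed to be 256, as defined in ggml's rope.cuh):

```cpp
// Launch geometry shared by rope_norm_cuda / rope_neox_cuda, computed on
// the host. Sizes below are invented for illustration.
#include <cstdio>

#define CUDA_ROPE_BLOCK_SIZE 256 // assumption: matches ggml's rope.cuh

int main() {
    const int ne0 = 128;    // invented head size
    const int nr  = 32*512; // invented row count (heads x tokens)

    // Each thread handles 2 elements, hence the factor of 2.
    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    // block_dims = (1, CUDA_ROPE_BLOCK_SIZE, 1), block_nums = (nr, n_blocks_x, 1)
    printf("block_nums = (%d, %d, 1), threads per block = %d\n",
           nr, n_blocks_x, CUDA_ROPE_BLOCK_SIZE);
    return 0;
}
```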
@@ -325,6 +391,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
+    const ggml_tensor * src3 = dst->src[3];
 
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
@@ -334,7 +401,9 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(src0->type == dst->type);
+    // When not fused, src0 and dst types must match.
+    // When fused (ROPE + VIEW + SET_ROWS), src0 may be F32 while dst is F16.
+    GGML_ASSERT(src0->type == dst->type || dst->type == GGML_TYPE_F16);
 
     const int64_t ne00 = src0->ne[0]; // head dims
     const int64_t ne01 = src0->ne[1]; // num heads
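
The relaxed assertion is what lets the fused path rotate F32 activations and store them directly into an F16 KV cache. The conversion itself is the `D(...)` cast in the kernels above; a self-contained sketch of that store-side cast (CUDA; data and sizes invented):

```cuda
// A store templated on the destination type, as in the kernels above.
// With D = half, the cast invokes CUDA's float -> half conversion.
#include <cuda_fp16.h>
#include <cstdio>

template <typename D>
static __global__ void store_as(const float * x, D * dst, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = D(x[i]); // no-op for D = float, f32 -> f16 for D = half
    }
}

int main() {
    const int n = 4;
    float     h_x[n] = { 0.1f, 1.5f, -2.0f, 3.25f };
    half      h_y[n];
    float *   d_x;
    half  *   d_y;
    cudaMalloc(&d_x, n*sizeof(float));
    cudaMalloc(&d_y, n*sizeof(half));
    cudaMemcpy(d_x, h_x, n*sizeof(float), cudaMemcpyHostToDevice);
    store_as<<<1, n>>>(d_x, d_y, n); // instantiated with D = half
    cudaMemcpy(h_y, d_y, n*sizeof(half), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        printf("%f -> %f\n", h_x[i], __half2float(h_y[i]));
    }
    cudaFree(d_x);
    cudaFree(d_y);
    return 0;
}
```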
@@ -386,19 +455,32 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         freq_factors = (const float *) src2->data;
     }
 
+    // Row indices for fused ROPE + VIEW + SET_ROWS
+    const int64_t * row_indices     = nullptr;
+    int             set_rows_stride = 0;
+    if (src3 != nullptr) {
+        GGML_ASSERT(src3->type == GGML_TYPE_I64);
+        row_indices     = (const int64_t *) src3->data;
+        set_rows_stride = ggml_get_op_params_i32(dst, 15);
+    }
+
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
     if (is_neox) {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda<forward>(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda<forward>(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                  freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+            rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                 freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
+                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                freq_factors, row_indices, set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
        }
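
On the host side, `set_rows_stride` arrives through the destination tensor's op-params block: ggml reserves GGML_MAX_OP_PARAMS (64) bytes per op, i.e. sixteen int32 slots, and the fusion stashes the stride in slot 15, the last one. A sketch of the round-trip using a stand-in struct (not ggml's real `ggml_tensor` or `ggml_get_op_params_i32`, which differ in detail):

```cpp
// Round-trip of an int32 op param through a stand-in tensor struct.
#include <cstdint>
#include <cstdio>

#define GGML_MAX_OP_PARAMS 64

struct fake_tensor {
    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; // 16 slots
};

static int32_t get_op_params_i32(const fake_tensor & t, uint32_t i) {
    return t.op_params[i];
}

int main() {
    fake_tensor t = {};
    t.op_params[15] = 4096; // hypothetical set_rows_stride in the last slot
    printf("set_rows_stride = %d\n", (int) get_op_params_i32(t, 15));
    return 0;
}
```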
@@ -427,14 +509,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             GGML_ABORT("fatal error");
         }
     } else {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_norm_cuda<forward>(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_norm_cuda<forward>(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                  freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+            rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                 freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
+                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                freq_factors, row_indices, set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }