// ============================================================================
// rope_norm kernel (unified-diff hunks; interior context elided at "@@" lines).
// Each thread rotates one pair of adjacent elements (x[ix+0], x[ix+1]).
// This change:
//   * splits the element type into input T and output D — every store now
//     converts explicitly via D(...), enabling e.g. F32 input with F16 output;
//   * adds an optional fused ROPE + VIEW + SET_ROWS scatter, driven by the new
//     row_indices / set_rows_stride parameters (set_rows_stride == 0 disables it).
// ============================================================================
@@ -37,11 +37,23 @@ static __device__ void rope_yarn(
3737 }
3838}
3939
40- template <bool forward, bool has_ff, typename T>
41- static __global__ void rope_norm (
42- const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
43- const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
44- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
// New signature: D is the destination element type; the two trailing
// parameters carry the fused-SET_ROWS scatter information.
40+ template <bool forward, bool has_ff, typename T, typename D>
41+ static __global__ void rope_norm (const T * x,
42+ D * dst,
43+ const int ne0,
44+ const int ne1,
45+ const int s1,
46+ const int s2,
47+ const int n_dims,
48+ const int32_t * pos,
49+ const float freq_scale,
50+ const float ext_factor,
51+ const float attn_factor,
52+ const rope_corr_dims corr_dims,
53+ const float theta_scale,
54+ const float * freq_factors,
55+ const int64_t * row_indices,
56+ const int set_rows_stride) {
// i0 indexes the first element of this thread's pair (hence the factor 2).
4557 const int i0 = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
4658
4759 if (i0 >= ne0) {
@@ -53,12 +65,19 @@ static __global__ void rope_norm(
5365 const int row_x = row_dst % ne1;
5466 const int channel_x = row_dst / ne1;
5567
// idst is no longer const: the fused path below rebases it.
56- const int idst = row_dst* ne0 + i0;
68+ int idst = row_dst * ne0 + i0;
5769 const int ix = channel_x*s2 + row_x*s1 + i0;
5870
71+ // Fusion optimization: ROPE + VIEW + SET_ROWS.
72+ // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
73+ if (set_rows_stride != 0 ) {
74+ idst = row_x * ne0 + i0;
75+ idst += row_indices[channel_x] * set_rows_stride;
76+ }
// NOTE(review): idst is int but row_indices[] is int64_t — the int64 product
// row_indices[channel_x] * set_rows_stride is narrowed to 32 bits on the
// compound assignment. TODO confirm destination offsets cannot exceed INT_MAX.
77+
// Tail region (i0 >= n_dims): elements past the rotated dims are copied
// through unchanged, now with an explicit T -> D conversion.
5978 if (i0 >= n_dims) {
60- dst[idst + 0 ] = x[ix + 0 ];
61- dst[idst + 1 ] = x[ix + 1 ];
79+ dst[idst + 0 ] = D ( x[ix + 0 ]) ;
80+ dst[idst + 1 ] = D ( x[ix + 1 ]) ;
6281
6382 return ;
6483 }
@@ -75,15 +94,27 @@ static __global__ void rope_norm(
// Rotation of the pair; math stays in float, only the final stores convert to D.
7594 const float x0 = x[ix + 0 ];
7695 const float x1 = x[ix + 1 ];
7796
78- dst[idst + 0 ] = x0* cos_theta - x1* sin_theta;
79- dst[idst + 1 ] = x0* sin_theta + x1* cos_theta;
97+ dst[idst + 0 ] = D (x0 * cos_theta - x1 * sin_theta) ;
98+ dst[idst + 1 ] = D (x0 * sin_theta + x1 * cos_theta) ;
8099}
81100
// ============================================================================
// rope_neox kernel (unified-diff hunks; interior context elided at "@@" lines).
// NeoX layout: each thread rotates the pair (x[ix+0], x[ix + n_dims/2]), i.e.
// elements half the rotated width apart, rather than adjacent elements.
// Same change set as rope_norm: separate output type D (stores convert via
// D(...)) plus an optional fused ROPE + VIEW + SET_ROWS scatter
// (row_indices / set_rows_stride; set_rows_stride == 0 disables it).
// ============================================================================
82- template <bool forward, bool has_ff, typename T>
83- static __global__ void rope_neox (
84- const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
85- const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
86- const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
101+ template <bool forward, bool has_ff, typename T, typename D>
102+ static __global__ void rope_neox (const T * x,
103+ D * dst,
104+ const int ne0,
105+ const int ne1,
106+ const int s1,
107+ const int s2,
108+ const int n_dims,
109+ const int32_t * pos,
110+ const float freq_scale,
111+ const float ext_factor,
112+ const float attn_factor,
113+ const rope_corr_dims corr_dims,
114+ const float theta_scale,
115+ const float * freq_factors,
116+ const int64_t * row_indices,
117+ const int set_rows_stride) {
// i0 counts elements in steps of 2; the NeoX layout divides it by 2 below.
87118 const int i0 = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
88119
89120 if (i0 >= ne0) {
@@ -95,12 +126,19 @@ static __global__ void rope_neox(
95126 const int row_x = row_dst % ne1;
96127 const int channel_x = row_dst / ne1;
97128
// idst is no longer const: the fused path below rebases it.
98- const int idst = row_dst* ne0 + i0/ 2 ;
129+ int idst = row_dst * ne0 + i0 / 2 ;
99130 const int ix = channel_x*s2 + row_x*s1 + i0/2 ;
100131
132+ // Fusion optimization: ROPE + VIEW + SET_ROWS.
133+ // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
134+ if (set_rows_stride != 0 ) {
135+ idst = row_x * ne0 + i0 / 2 ;
136+ idst += row_indices[channel_x] * set_rows_stride;
137+ }
// NOTE(review): same narrowing as rope_norm — idst is int while
// row_indices[] is int64_t; the 64-bit product is truncated to 32 bits.
// TODO confirm destination offsets cannot exceed INT_MAX.
138+
// Tail region (i0 >= n_dims): pass-through copy of two adjacent tail
// elements, with explicit T -> D conversion.
101139 if (i0 >= n_dims) {
102- dst[idst + i0/ 2 + 0 ] = x[ix + i0/ 2 + 0 ];
103- dst[idst + i0/ 2 + 1 ] = x[ix + i0/ 2 + 1 ];
140+ dst[idst + i0 / 2 + 0 ] = D ( x[ix + i0 / 2 + 0 ]) ;
141+ dst[idst + i0 / 2 + 1 ] = D ( x[ix + i0 / 2 + 1 ]) ;
104142
105143 return ;
106144 }
@@ -117,8 +155,8 @@ static __global__ void rope_neox(
// Rotation of the half-width-apart pair; math in float, stores convert to D.
117155 const float x0 = x[ix + 0 ];
118156 const float x1 = x[ix + n_dims/2 ];
119157
120- dst[idst + 0 ] = x0* cos_theta - x1* sin_theta;
121- dst[idst + n_dims/ 2 ] = x0* sin_theta + x1* cos_theta;
158+ dst[idst + 0 ] = D (x0 * cos_theta - x1 * sin_theta) ;
159+ dst[idst + n_dims / 2 ] = D (x0 * sin_theta + x1 * cos_theta) ;
122160}
123161
124162template <bool forward, bool has_ff, typename T>
@@ -238,11 +276,25 @@ static __global__ void rope_vision(
238276 dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
239277}
240278
// ============================================================================
// rope_norm_cuda: host-side launcher for the rope_norm kernel (diff hunks).
// Signature gains the D output type plus the row_indices / set_rows_stride
// fusion parameters, which are forwarded verbatim to the kernel.  The has_ff
// template flag is selected from whether freq_factors is null.
// ============================================================================
241- template <bool forward, typename T>
242- static void rope_norm_cuda (
243- const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
244- const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
245- const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
279+ template <bool forward, typename T, typename D>
280+ static void rope_norm_cuda (const T * x,
281+ D * dst,
282+ const int ne0,
283+ const int ne1,
284+ const int s1,
285+ const int s2,
286+ const int n_dims,
287+ const int nr,
288+ const int32_t * pos,
289+ const float freq_scale,
290+ const float freq_base,
291+ const float ext_factor,
292+ const float attn_factor,
293+ const rope_corr_dims corr_dims,
294+ const float * freq_factors,
295+ const int64_t * row_indices,
296+ const int set_rows_stride,
297+ cudaStream_t stream) {
// The kernel processes element pairs, so ne0 must be even and the x-grid
// covers ne0/2 pairs per CUDA_ROPE_BLOCK_SIZE-wide block.
246298 GGML_ASSERT (ne0 % 2 == 0 );
247299 const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
248300 const int n_blocks_x = (ne0 + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
@@ -252,20 +304,34 @@ static void rope_norm_cuda(
252304
// Dispatch on has_ff (freq_factors present or not); T and D are deduced from
// the x/dst arguments.  The two new trailing kernel arguments carry the
// fused-SET_ROWS information.
253305 if (freq_factors == nullptr ) {
254306 rope_norm<forward, false ><<<block_nums, block_dims, 0 , stream>>> (
255- x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
256- attn_factor, corr_dims, theta_scale, freq_factors );
307+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
308+ freq_factors, row_indices, set_rows_stride );
257309 } else {
258310 rope_norm<forward, true ><<<block_nums, block_dims, 0 , stream>>> (
259- x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
260- attn_factor, corr_dims, theta_scale, freq_factors );
311+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
312+ freq_factors, row_indices, set_rows_stride );
261313 }
262314}
263315
// ============================================================================
// rope_neox_cuda: host-side launcher for the rope_neox kernel (diff hunks).
// Mirrors rope_norm_cuda: adds the D output type and the
// row_indices / set_rows_stride fusion parameters.  The call sites also drop
// the previously explicit T template argument (rope_neox<forward, false, T>
// becomes rope_neox<forward, false>), letting T and D deduce from arguments.
// ============================================================================
264- template <bool forward, typename T>
265- static void rope_neox_cuda (
266- const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
267- const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
268- const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
316+ template <bool forward, typename T, typename D>
317+ static void rope_neox_cuda (const T * x,
318+ D * dst,
319+ const int ne0,
320+ const int ne1,
321+ const int s1,
322+ const int s2,
323+ const int n_dims,
324+ const int nr,
325+ const int32_t * pos,
326+ const float freq_scale,
327+ const float freq_base,
328+ const float ext_factor,
329+ const float attn_factor,
330+ const rope_corr_dims corr_dims,
331+ const float * freq_factors,
332+ const int64_t * row_indices,
333+ const int set_rows_stride,
334+ cudaStream_t stream) {
// Pairwise processing: ne0 must be even; x-grid covers ne0/2 pairs.
269335 GGML_ASSERT (ne0 % 2 == 0 );
270336 const dim3 block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
271337 const int n_blocks_x = (ne0 + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
@@ -274,13 +340,13 @@ static void rope_neox_cuda(
// theta_scale = freq_base^(-2/n_dims), the per-pair frequency decay.
274340 const float theta_scale = powf (freq_base, -2 .0f /n_dims);
275341
276342 if (freq_factors == nullptr ) {
277- rope_neox<forward, false , T ><<<block_nums, block_dims, 0 , stream>>> (
278- x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
279- attn_factor, corr_dims, theta_scale, freq_factors );
343+ rope_neox<forward, false ><<<block_nums, block_dims, 0 , stream>>> (
344+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
345+ freq_factors, row_indices, set_rows_stride );
280346 } else {
281- rope_neox<forward, true , T ><<<block_nums, block_dims, 0 , stream>>> (
282- x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
283- attn_factor, corr_dims, theta_scale, freq_factors );
347+ rope_neox<forward, true ><<<block_nums, block_dims, 0 , stream>>> (
348+ x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
349+ freq_factors, row_indices, set_rows_stride );
284350 }
285351}
286352
// ============================================================================
// ggml_cuda_op_rope_impl (diff hunks): ggml-op entry point for ROPE.
// This change wires in the fused ROPE + VIEW + SET_ROWS path:
//   * src3 (dst->src[3]), when present, is an I64 tensor of destination row
//     indices; the scatter stride is read from op-params slot 15;
//   * the src0/dst type-match assert is relaxed so fused F32 -> F16 output is
//     allowed, and the kernel dispatch now switches on the (src0, dst) type
//     pair (F16 src with F32 dst falls through to GGML_ABORT).
// ============================================================================
@@ -337,6 +403,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
337403 const ggml_tensor * src0 = dst->src [0 ];
338404 const ggml_tensor * src1 = dst->src [1 ];
339405 const ggml_tensor * src2 = dst->src [2 ];
// src3: optional row-indices tensor for the fused SET_ROWS scatter.
406+ const ggml_tensor * src3 = dst->src [3 ];
340407
341408 const float * src0_d = (const float *)src0->data ;
342409 const float * src1_d = (const float *)src1->data ;
@@ -346,7 +413,9 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
346413
347414 GGML_ASSERT (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
348415 GGML_ASSERT ( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
349- GGML_ASSERT (src0->type == dst->type );
416+ // When not fused, src0 and dst types must match
417+ // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
418+ GGML_ASSERT (src0->type == dst->type || dst->type == GGML_TYPE_F16);
// NOTE(review): as written, the relaxed assert also admits F32 src0 with F16
// dst when no fusion is active (src3 == nullptr) — presumably that combination
// is only produced by the fusion pass; confirm against the graph builder.
350419
351420 const int64_t ne00 = src0->ne [0 ]; // head dims
352421 const int64_t ne01 = src0->ne [1 ]; // num heads
@@ -399,19 +468,32 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
399468 freq_factors = (const float *) src2->data ;
400469 }
401470
471+ // Row indices for fused ROPE + VIEW + SET_ROWS
472+ const int64_t * row_indices = nullptr ;
473+ int set_rows_stride = 0 ;
474+ if (src3 != nullptr ) {
475+ GGML_ASSERT (src3->type == GGML_TYPE_I64);
476+ row_indices = (const int64_t *) src3->data ;
// Stride between scattered rows, stashed by the fusion pass in op-param 15.
477+ set_rows_stride = ggml_get_op_params_i32 (dst, 15 );
478+ }
479+
402480 rope_corr_dims corr_dims;
403481 ggml_rope_yarn_corr_dims (n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v );
404482
405483 // compute
// Dispatch: template arguments are now the explicit (input, output) element
// types; every launcher additionally receives row_indices / set_rows_stride.
406484 if (is_neox) {
407- if (src0->type == GGML_TYPE_F32) {
408- rope_neox_cuda<forward>(
409- (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
410- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
411- } else if (src0->type == GGML_TYPE_F16) {
412- rope_neox_cuda<forward>(
413- (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
414- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
485+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
486+ rope_neox_cuda<forward, float , float >((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
487+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
488+ freq_factors, row_indices, set_rows_stride, stream);
489+ } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
490+ rope_neox_cuda<forward, float , half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
491+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
492+ freq_factors, row_indices, set_rows_stride, stream);
493+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
494+ rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
495+ pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
496+ freq_factors, row_indices, set_rows_stride, stream);
// F16 src0 with F32 dst is not implemented: falls through to GGML_ABORT.
415497 } else {
416498 GGML_ABORT (" fatal error" );
417499 }
@@ -440,14 +522,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
440522 GGML_ABORT (" fatal error" );
441523 }
// Non-NeoX (normal) layout: same (src0, dst) type-pair dispatch as above.
442524 } else {
443- if (src0->type == GGML_TYPE_F32) {
444- rope_norm_cuda<forward>(
445- (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
446- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
447- } else if (src0->type == GGML_TYPE_F16) {
448- rope_norm_cuda<forward>(
449- (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
450- freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
525+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
526+ rope_norm_cuda<forward, float , float >((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
527+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
528+ freq_factors, row_indices, set_rows_stride, stream);
529+ } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
530+ rope_norm_cuda<forward, float , half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
531+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
532+ freq_factors, row_indices, set_rows_stride, stream);
533+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
534+ rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
535+ pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
536+ freq_factors, row_indices, set_rows_stride, stream);
451537 } else {
452538 GGML_ABORT (" fatal error" );
453539 }
(end of diff excerpt — page footer: 0 commit comments)