@@ -1,3 +1,5 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
 #include "rope.cuh"
 
 struct rope_corr_dims {
@@ -399,16 +401,28 @@ static void rope_vision_cuda(
 }
 
 template <bool forward>
-void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
+                            ggml_tensor * dst,
+                            const ggml_tensor * set_rows = nullptr) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
-    const ggml_tensor * src3 = dst->src[3];
 
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
 
-    float * dst_d = (float *)dst->data;
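+    // keep the destination pointer untyped: the fused path may write half-precision rows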
+    void * dst_d = dst->data;
+    const int64_t * row_indices = nullptr;
+    ggml_type dst_type = dst->type;
+    int set_rows_stride = 0;
+
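+    // fused ROPE + VIEW + SET_ROWS: write the rotated rows directly into the
+    // SET_ROWS destination, scattered through its row indices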
+    if (set_rows != nullptr) {
+        GGML_ASSERT(forward);
+        dst_d = set_rows->data;
+        row_indices = (const int64_t *) set_rows->src[1]->data;
+        dst_type = set_rows->type;
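+        // nb[1] is the row stride in bytes; convert to elements of the destination type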
+        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
+    }
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
@@ -468,29 +482,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         freq_factors = (const float *) src2->data;
     }
 
-    // Row indices for fused ROPE + VIEW + SET_ROWS
-    const int64_t * row_indices = nullptr;
-    int set_rows_stride = 0;
-    if (src3 != nullptr) {
-        GGML_ASSERT(src3->type == GGML_TYPE_I64);
-        row_indices = (const int64_t *) src3->data;
-        set_rows_stride = ggml_get_op_params_i32(dst, 15);
-    }
-
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
     if (is_neox) {
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
             rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                   nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                   freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
             rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
             rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                 pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
@@ -522,15 +527,15 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             GGML_ABORT("fatal error");
         }
     } else {
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
             rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                   nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                   freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
             rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
             rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                 pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
@@ -547,3 +552,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_rope_impl<false>(ctx, dst);
 }
+
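+// forward-only entry point for the fused ROPE + VIEW + SET_ROWS path:
+// the rope kernels scatter their output through set_rows' row indices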
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
+    ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
+}