@@ -1,3 +1,5 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
 #include "rope.cuh"
 
 struct rope_corr_dims {
@@ -387,16 +389,28 @@ static void rope_vision_cuda(
 }
 
 template <bool forward>
-void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
+                            ggml_tensor * dst,
+                            const ggml_tensor * set_rows = nullptr) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
-    const ggml_tensor * src3 = dst->src[3];
 
     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
 
-    float * dst_d = (float *) dst->data;
+    void * dst_d = dst->data;
+    const int64_t * row_indices = nullptr;
+    ggml_type dst_type = dst->type;
+    int set_rows_stride = 0;
+
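+    // fused ROPE + VIEW + SET_ROWS: write the rotated rows directly into the SET_ROWS destination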
+    if (set_rows != nullptr) {
+        GGML_ASSERT(forward);
+        dst_d = set_rows->data;
+        row_indices = (const int64_t *) set_rows->src[1]->data;
+        dst_type = set_rows->type;
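+        // nb[1] is a byte stride; convert it to an element stride for the destination rows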
+        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
+    }
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
@@ -455,29 +469,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         freq_factors = (const float *) src2->data;
     }
 
-    // Row indices for fused ROPE + VIEW + SET_ROWS
-    const int64_t * row_indices = nullptr;
-    int set_rows_stride = 0;
-    if (src3 != nullptr) {
-        GGML_ASSERT(src3->type == GGML_TYPE_I64);
-        row_indices = (const int64_t *) src3->data;
-        set_rows_stride = ggml_get_op_params_i32(dst, 15);
-    }
-
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
     if (is_neox) {
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
             rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                   nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                   freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
             rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
             rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                 pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
@@ -509,15 +514,15 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             GGML_ABORT("fatal error");
         }
     } else {
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
             rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                   nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                   freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
             rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
             rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                 pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
@@ -534,3 +539,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_rope_impl<false>(ctx, dst);
 }
+
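+// forward-only fused ROPE + VIEW + SET_ROWS entry point; set_rows supplies the
+// destination buffer, the row indices (its src[1]) and the output type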
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
+    ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
+}