
Commit dc814b8

create helper function instead of re-using params
1 parent 607f73b commit dc814b8

3 files changed: +31 additions, −27 deletions

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 9 deletions
@@ -3211,17 +3211,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }
 
         if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
-            ggml_tensor * src3 = cgraph->nodes[i + 2]->src[1];
             ggml_tensor * rope = cgraph->nodes[i];
-            ggml_tensor * dst  = cgraph->nodes[i + 2];
+            ggml_tensor * set_rows = cgraph->nodes[i + 2];
 
-            rope->src[3] = src3;
-            rope->data   = dst->data;
-            rope->type   = dst->type;
-
-            const size_t set_rows_stride = dst->nb[1] / ggml_type_size(dst->type);
-            ggml_set_op_params_i32(rope, 15, set_rows_stride);
-            ggml_cuda_op_rope(*cuda_ctx, rope);
+            ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
             i += 2;
             continue;
         }
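For context: the deleted lines show that the fusion previously worked by mutating the rope node in place (setting rope->src[3], rope->data, rope->type, and op-param slot 15) before calling the regular ggml_cuda_op_rope. A minimal sketch of the pattern this branch handles after the change; the node roles are inferred from the op list passed to ggml_cuda_can_fuse and the comments are not verbatim from the source:

    // Fused subgraph matched by this branch (roles inferred):
    //   cgraph->nodes[i]     : GGML_OP_ROPE      -- rotates the input rows
    //   cgraph->nodes[i + 1] : GGML_OP_VIEW      -- reinterprets the destination buffer
    //   cgraph->nodes[i + 2] : GGML_OP_SET_ROWS  -- scatters the rotated rows by index
    ggml_tensor * rope     = cgraph->nodes[i];
    ggml_tensor * set_rows = cgraph->nodes[i + 2];
    ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);   // the rope node itself is no longer modified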

ggml/src/ggml-cuda/rope.cu

Lines changed: 27 additions & 18 deletions
@@ -1,3 +1,5 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
 #include "rope.cuh"
 
 struct rope_corr_dims {
@@ -399,16 +401,28 @@ static void rope_vision_cuda(
 }
 
 template <bool forward>
-void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
+                            ggml_tensor * dst,
+                            const ggml_tensor * set_rows = nullptr) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
-    const ggml_tensor * src3 = dst->src[3];
 
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
 
-    float * dst_d = (float *)dst->data;
+    void * dst_d = dst->data;
+    const int64_t * row_indices = nullptr;
+    ggml_type dst_type = dst->type;
+    int set_rows_stride = 0;
+
+    if (set_rows != nullptr) {
+        GGML_ASSERT(forward);
+        dst_d = set_rows->data;
+        row_indices = (const int64_t *) set_rows->src[1]->data;
+        dst_type = set_rows->type;
+        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
+    }
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
@@ -468,29 +482,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         freq_factors = (const float *) src2->data;
     }
 
-    // Row indices for fused ROPE + VIEW + SET_ROWS
-    const int64_t * row_indices = nullptr;
-    int set_rows_stride = 0;
-    if (src3 != nullptr) {
-        GGML_ASSERT(src3->type == GGML_TYPE_I64);
-        row_indices = (const int64_t *) src3->data;
-        set_rows_stride = ggml_get_op_params_i32(dst, 15);
-    }
-
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
     if (is_neox) {
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
             rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                   nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                   freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
             rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
             rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                 pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
@@ -522,15 +527,15 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             GGML_ABORT("fatal error");
         }
     } else {
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
             rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                   nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                   freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
             rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
             rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                 pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
@@ -547,3 +552,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_rope_impl<false>(ctx, dst);
 }
+
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
+    ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
+}
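The only arithmetic worth calling out is the stride conversion inside the new set_rows branch: nb[1] is ggml's byte stride between consecutive rows, and the kernels want it in elements of the destination type. A worked example with a hypothetical F16 destination (the 8192-byte row stride is made up purely for illustration):

    const size_t nb1       = 8192;                    // set_rows->nb[1]: bytes between rows (hypothetical value)
    const size_t type_size = 2;                       // ggml_type_size(GGML_TYPE_F16)
    const int    stride    = (int) (nb1 / type_size); // 4096 half elements between rows
    // With row_indices taken from set_rows->src[1], each rotated row r is then
    // presumably written at an offset of row_indices[r] * stride elements into
    // set_rows->data, rather than into a densely packed buffer.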

ggml/src/ggml-cuda/rope.cuh

Lines changed: 2 additions & 0 deletions
@@ -5,3 +5,5 @@
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
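Taken together with rope.cu, all three public entry points now funnel into the same templated implementation. A sketch of the mapping after this commit; ggml_cuda_op_rope's body is outside the diff, so its exact form below is an assumption:

    void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
        ggml_cuda_op_rope_impl<true>(ctx, dst);             // set_rows defaults to nullptr
    }

    void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
        ggml_cuda_op_rope_impl<false>(ctx, dst);            // backward pass, never fused (GGML_ASSERT(forward))
    }

    void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
        ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);  // fused ROPE + VIEW + SET_ROWS
    }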
