
Commit 406c867

create helper function instead of re-using params
1 parent e510907 commit 406c867

3 files changed, 31 additions and 27 deletions


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 9 deletions
@@ -3210,17 +3210,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         }

         if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
-            ggml_tensor * src3 = cgraph->nodes[i + 2]->src[1];
             ggml_tensor * rope = cgraph->nodes[i];
-            ggml_tensor * dst = cgraph->nodes[i + 2];
+            ggml_tensor * set_rows = cgraph->nodes[i + 2];

-            rope->src[3] = src3;
-            rope->data = dst->data;
-            rope->type = dst->type;
-
-            const size_t set_rows_stride = dst->nb[1] / ggml_type_size(dst->type);
-            ggml_set_op_params_i32(rope, 15, set_rows_stride);
-            ggml_cuda_op_rope(*cuda_ctx, rope);
+            ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
             i += 2;
             continue;
         }
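
For context on why the helper exists: the removed call-site code made the existing ggml_cuda_op_rope entry point perform the fused write by mutating the ROPE node itself. A commented restatement of those deleted lines (a readability sketch, not new upstream code):

// old fused path: repurpose fields of the ROPE node before calling the regular op
ggml_tensor * src3 = cgraph->nodes[i + 2]->src[1];     // row indices taken from the SET_ROWS node
rope->src[3] = src3;                                   // attach them as an extra source
rope->data   = dst->data;                              // redirect the output into the SET_ROWS buffer
rope->type   = dst->type;                              // and adopt its destination type

const size_t set_rows_stride = dst->nb[1] / ggml_type_size(dst->type);
ggml_set_op_params_i32(rope, 15, set_rows_stride);     // stash the row stride in op-param slot 15
ggml_cuda_op_rope(*cuda_ctx, rope);                    // regular entry point, now writing elsewhere

The new ggml_cuda_op_rope_fused helper keeps this bookkeeping out of the graph-evaluation loop and leaves the ROPE node unmodified.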

ggml/src/ggml-cuda/rope.cu

Lines changed: 27 additions & 18 deletions
@@ -1,3 +1,5 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
 #include "rope.cuh"

 struct rope_corr_dims {
@@ -387,16 +389,28 @@ static void rope_vision_cuda(
 }

 template <bool forward>
-void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
+                            ggml_tensor * dst,
+                            const ggml_tensor * set_rows = nullptr) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
-    const ggml_tensor * src3 = dst->src[3];

     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;

-    float * dst_d = (float *)dst->data;
+    void * dst_d = dst->data;
+    const int64_t * row_indices = nullptr;
+    ggml_type dst_type = dst->type;
+    int set_rows_stride = 0;
+
+    if (set_rows != nullptr) {
+        GGML_ASSERT(forward);
+        dst_d = set_rows->data;
+        row_indices = (const int64_t *) set_rows->src[1]->data;
+        dst_type = set_rows->type;
+        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
+    }
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
@@ -455,29 +469,20 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         freq_factors = (const float *) src2->data;
     }

-    // Row indices for fused ROPE + VIEW + SET_ROWS
-    const int64_t * row_indices = nullptr;
-    int set_rows_stride = 0;
-    if (src3 != nullptr) {
-        GGML_ASSERT(src3->type == GGML_TYPE_I64);
-        row_indices = (const int64_t *) src3->data;
-        set_rows_stride = ggml_get_op_params_i32(dst, 15);
-    }
-
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

     // compute
     if (is_neox) {
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
             rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                   nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                   freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
             rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
             rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                 pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
@@ -509,15 +514,15 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             GGML_ABORT("fatal error");
         }
     } else {
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
             rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                   nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                   freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
             rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
-        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
             rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                 pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
@@ -534,3 +539,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_rope_impl<false>(ctx, dst);
 }
+
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
+    ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
+}

ggml/src/ggml-cuda/rope.cuh

Lines changed: 2 additions & 0 deletions
@@ -5,3 +5,5 @@
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

 void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
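
Taken together, a minimal sketch of how a caller uses the new helper once ggml_cuda_can_fuse has matched the ROPE → VIEW → SET_ROWS pattern (variable names follow the diffs above; illustrative only, not additional upstream code):

ggml_tensor * rope     = cgraph->nodes[i];          // the ROPE node
ggml_tensor * set_rows = cgraph->nodes[i + 2];      // supplies destination buffer, row indices, type and stride

ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); // forward RoPE only; the impl asserts this
i += 2;                                             // skip the fused VIEW and SET_ROWS nodes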
