Skip to content

Commit eac4bde

Browse files
committed
CUDA: add fused rope
1 parent 229bf68 commit eac4bde

File tree

3 files changed: +201 −58 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2963,6 +2963,36 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif

+static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
+                                                const ggml_tensor * view,
+                                                const ggml_tensor * set_rows) {
+    // ne3 not tested
+    if (rope->src[0]->ne[3] != 1) {
+        return false;
+    }
+
+    if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
+        return false;
+    }
+
+    if (set_rows->src[1]->type != GGML_TYPE_I64) {
+        return false;
+    }
+
+    // The view should flatten two dims of rope into one dim
+    if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
+        return false;
+    }
+
+    // Only norm/neox shaders have the fusion code
+    const int mode = ((const int32_t *) rope->op_params)[2];
+    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+        return false;
+    }
+
+    return true;
+}
+
 static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
 #ifndef NDEBUG
     const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
@@ -3038,6 +3068,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }

+    if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
+        const ggml_tensor * rope     = cgraph->nodes[node_idx];
+        const ggml_tensor * view     = cgraph->nodes[node_idx + 1];
+        const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
+
+        if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -3169,6 +3209,22 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 continue;
             }

+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
+                ggml_tensor * src3 = cgraph->nodes[i + 2]->src[1];
+                ggml_tensor * rope = cgraph->nodes[i];
+                ggml_tensor * dst  = cgraph->nodes[i + 2];
+
+                rope->src[3] = src3;
+                rope->data   = dst->data;
+                rope->type   = dst->type;
+
+                const size_t set_rows_stride = dst->nb[1] / ggml_type_size(dst->type);
+                ggml_set_op_params_i32(rope, 15, set_rows_stride);
+                ggml_cuda_op_rope(*cuda_ctx, rope);
+                i += 2;
+                continue;
+            }
+
             if (node->op == GGML_OP_ADD) {
                 int n_fuse = 0;
                 ggml_op ops[8];

ggml/src/ggml-cuda/rope.cu

Lines changed: 143 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,23 @@ static __device__ void rope_yarn(
     }
 }

-template<bool forward, bool has_ff, typename T>
-static __global__ void rope_norm(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-    const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_norm(const T * x,
+                                 D * dst,
+                                 const int ne0,
+                                 const int ne1,
+                                 const int s1,
+                                 const int s2,
+                                 const int n_dims,
+                                 const int32_t * pos,
+                                 const float freq_scale,
+                                 const float ext_factor,
+                                 const float attn_factor,
+                                 const rope_corr_dims corr_dims,
+                                 const float theta_scale,
+                                 const float * freq_factors,
+                                 const int64_t * row_indices,
+                                 const int set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (i0 >= ne0) {
@@ -53,12 +65,19 @@ static __global__ void rope_norm(
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;

-    const int idst = row_dst*ne0 + i0;
+    int idst = row_dst * ne0 + i0;
     const int ix = channel_x*s2 + row_x*s1 + i0;

+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+    if (set_rows_stride != 0) {
+        idst  = row_x * ne0 + i0;
+        idst += row_indices[channel_x] * set_rows_stride;
+    }
+
     if (i0 >= n_dims) {
-        dst[idst + 0] = x[ix + 0];
-        dst[idst + 1] = x[ix + 1];
+        dst[idst + 0] = D(x[ix + 0]);
+        dst[idst + 1] = D(x[ix + 1]);

         return;
     }
@@ -75,15 +94,27 @@ static __global__ void rope_norm(
7594
const float x0 = x[ix + 0];
7695
const float x1 = x[ix + 1];
7796

78-
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
79-
dst[idst + 1] = x0*sin_theta + x1*cos_theta;
97+
dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta);
98+
dst[idst + 1] = D(x0 * sin_theta + x1 * cos_theta);
8099
}
81100

82-
template<bool forward, bool has_ff, typename T>
83-
static __global__ void rope_neox(
84-
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
85-
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
86-
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
101+
template <bool forward, bool has_ff, typename T, typename D>
102+
static __global__ void rope_neox(const T * x,
103+
D * dst,
104+
const int ne0,
105+
const int ne1,
106+
const int s1,
107+
const int s2,
108+
const int n_dims,
109+
const int32_t * pos,
110+
const float freq_scale,
111+
const float ext_factor,
112+
const float attn_factor,
113+
const rope_corr_dims corr_dims,
114+
const float theta_scale,
115+
const float * freq_factors,
116+
const int64_t * row_indices,
117+
const int set_rows_stride) {
87118
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
88119

89120
if (i0 >= ne0) {
@@ -95,12 +126,19 @@ static __global__ void rope_neox(
95126
const int row_x = row_dst % ne1;
96127
const int channel_x = row_dst / ne1;
97128

98-
const int idst = row_dst*ne0 + i0/2;
129+
int idst = row_dst * ne0 + i0 / 2;
99130
const int ix = channel_x*s2 + row_x*s1 + i0/2;
100131

132+
// Fusion optimization: ROPE + VIEW + SET_ROWS.
133+
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
134+
if (set_rows_stride != 0) {
135+
idst = row_x * ne0 + i0 / 2;
136+
idst += row_indices[channel_x] * set_rows_stride;
137+
}
138+
101139
if (i0 >= n_dims) {
102-
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
103-
dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
140+
dst[idst + i0 / 2 + 0] = D(x[ix + i0 / 2 + 0]);
141+
dst[idst + i0 / 2 + 1] = D(x[ix + i0 / 2 + 1]);
104142

105143
return;
106144
}
@@ -117,8 +155,8 @@ static __global__ void rope_neox(
117155
const float x0 = x[ix + 0];
118156
const float x1 = x[ix + n_dims/2];
119157

120-
dst[idst + 0] = x0*cos_theta - x1*sin_theta;
121-
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
158+
dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta);
159+
dst[idst + n_dims / 2] = D(x0 * sin_theta + x1 * cos_theta);
122160
}
123161

124162
template<bool forward, bool has_ff, typename T>
@@ -226,11 +264,25 @@ static __global__ void rope_vision(
226264
dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
227265
}
228266

229-
template<bool forward, typename T>
230-
static void rope_norm_cuda(
231-
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
232-
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
233-
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
267+
template <bool forward, typename T, typename D>
268+
static void rope_norm_cuda(const T * x,
269+
D * dst,
270+
const int ne0,
271+
const int ne1,
272+
const int s1,
273+
const int s2,
274+
const int n_dims,
275+
const int nr,
276+
const int32_t * pos,
277+
const float freq_scale,
278+
const float freq_base,
279+
const float ext_factor,
280+
const float attn_factor,
281+
const rope_corr_dims corr_dims,
282+
const float * freq_factors,
283+
const int64_t * row_indices,
284+
const int set_rows_stride,
285+
cudaStream_t stream) {
234286
GGML_ASSERT(ne0 % 2 == 0);
235287
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
236288
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -240,20 +292,34 @@ static void rope_norm_cuda(
240292

241293
if (freq_factors == nullptr) {
242294
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
243-
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
244-
attn_factor, corr_dims, theta_scale, freq_factors);
295+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
296+
freq_factors, row_indices, set_rows_stride);
245297
} else {
246298
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
247-
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
248-
attn_factor, corr_dims, theta_scale, freq_factors);
299+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
300+
freq_factors, row_indices, set_rows_stride);
249301
}
250302
}
251303

252-
template<bool forward, typename T>
253-
static void rope_neox_cuda(
254-
const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
255-
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
256-
const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
304+
template <bool forward, typename T, typename D>
305+
static void rope_neox_cuda(const T * x,
306+
D * dst,
307+
const int ne0,
308+
const int ne1,
309+
const int s1,
310+
const int s2,
311+
const int n_dims,
312+
const int nr,
313+
const int32_t * pos,
314+
const float freq_scale,
315+
const float freq_base,
316+
const float ext_factor,
317+
const float attn_factor,
318+
const rope_corr_dims corr_dims,
319+
const float * freq_factors,
320+
const int64_t * row_indices,
321+
const int set_rows_stride,
322+
cudaStream_t stream) {
257323
GGML_ASSERT(ne0 % 2 == 0);
258324
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
259325
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -262,13 +328,13 @@ static void rope_neox_cuda(
262328
const float theta_scale = powf(freq_base, -2.0f/n_dims);
263329

264330
if (freq_factors == nullptr) {
265-
rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
266-
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
267-
attn_factor, corr_dims, theta_scale, freq_factors);
331+
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
332+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
333+
freq_factors, row_indices, set_rows_stride);
268334
} else {
269-
rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
270-
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
271-
attn_factor, corr_dims, theta_scale, freq_factors);
335+
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
336+
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
337+
freq_factors, row_indices, set_rows_stride);
272338
}
273339
}
274340

@@ -325,6 +391,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
325391
const ggml_tensor * src0 = dst->src[0];
326392
const ggml_tensor * src1 = dst->src[1];
327393
const ggml_tensor * src2 = dst->src[2];
394+
const ggml_tensor * src3 = dst->src[3];
328395

329396
const float * src0_d = (const float *)src0->data;
330397
const float * src1_d = (const float *)src1->data;
@@ -334,7 +401,9 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
334401

335402
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
336403
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
337-
GGML_ASSERT(src0->type == dst->type);
404+
// When not fused, src0 and dst types must match
405+
// When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
406+
GGML_ASSERT(src0->type == dst->type || dst->type == GGML_TYPE_F16);
338407

339408
const int64_t ne00 = src0->ne[0]; // head dims
340409
const int64_t ne01 = src0->ne[1]; // num heads
@@ -386,19 +455,32 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
386455
freq_factors = (const float *) src2->data;
387456
}
388457

458+
// Row indices for fused ROPE + VIEW + SET_ROWS
459+
const int64_t * row_indices = nullptr;
460+
int set_rows_stride = 0;
461+
if (src3 != nullptr) {
462+
GGML_ASSERT(src3->type == GGML_TYPE_I64);
463+
row_indices = (const int64_t *) src3->data;
464+
set_rows_stride = ggml_get_op_params_i32(dst, 15);
465+
}
466+
389467
rope_corr_dims corr_dims;
390468
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
391469

392470
// compute
393471
if (is_neox) {
394-
if (src0->type == GGML_TYPE_F32) {
395-
rope_neox_cuda<forward>(
396-
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
397-
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
398-
} else if (src0->type == GGML_TYPE_F16) {
399-
rope_neox_cuda<forward>(
400-
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
401-
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
472+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
473+
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
474+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
475+
freq_factors, row_indices, set_rows_stride, stream);
476+
} else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
477+
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
478+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
479+
freq_factors, row_indices, set_rows_stride, stream);
480+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
481+
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
482+
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
483+
freq_factors, row_indices, set_rows_stride, stream);
402484
} else {
403485
GGML_ABORT("fatal error");
404486
}
@@ -427,14 +509,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
427509
GGML_ABORT("fatal error");
428510
}
429511
} else {
430-
if (src0->type == GGML_TYPE_F32) {
431-
rope_norm_cuda<forward>(
432-
(const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
433-
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
434-
} else if (src0->type == GGML_TYPE_F16) {
435-
rope_norm_cuda<forward>(
436-
(const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
437-
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
512+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
513+
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
514+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
515+
freq_factors, row_indices, set_rows_stride, stream);
516+
} else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
517+
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
518+
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
519+
freq_factors, row_indices, set_rows_stride, stream);
520+
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
521+
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
522+
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
523+
freq_factors, row_indices, set_rows_stride, stream);
438524
} else {
439525
GGML_ABORT("fatal error");
440526
}

src/llama-graph.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1592,14 +1592,15 @@ ggml_tensor * llm_graph_context::build_attn(
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
+    // expand k later to enable rope fusion which directly writes into k-v cache
     ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);

     const auto * mctx_cur = inp->mctx;

     // store to KV cache
     {
+        ggml_build_forward_expand(gf, k_cur);
         const auto & k_idxs = inp->get_k_idxs();
         const auto & v_idxs = inp->get_v_idxs();
16051606

0 commit comments

Comments
 (0)