
Commit ea859a2

CUDA: add fused rope
1 parent 4146d6a commit ea859a2

File tree

3 files changed: +201, -58 lines changed


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 56 additions & 0 deletions
@@ -2964,6 +2964,36 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
+static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
+                                                const ggml_tensor * view,
+                                                const ggml_tensor * set_rows) {
+    // ne3 not tested
+    if (rope->src[0]->ne[3] != 1) {
+        return false;
+    }
+
+    if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
+        return false;
+    }
+
+    if (set_rows->src[1]->type != GGML_TYPE_I64) {
+        return false;
+    }
+
+    // The view should flatten two dims of rope into one dim
+    if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
+        return false;
+    }
+
+    // Only norm/neox shaders have the fusion code
+    const int mode = ((const int32_t *) rope->op_params)[2];
+    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+        return false;
+    }
+
+    return true;
+}
+
 static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
 #ifndef NDEBUG
     const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
@@ -3039,6 +3069,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
+        const ggml_tensor * rope     = cgraph->nodes[node_idx];
+        const ggml_tensor * view     = cgraph->nodes[node_idx + 1];
+        const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
+
+        if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -3170,6 +3210,22 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             continue;
         }
 
+        if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
+            ggml_tensor * src3 = cgraph->nodes[i + 2]->src[1];
+            ggml_tensor * rope = cgraph->nodes[i];
+            ggml_tensor * dst  = cgraph->nodes[i + 2];
+
+            rope->src[3] = src3;
+            rope->data   = dst->data;
+            rope->type   = dst->type;
+
+            const size_t set_rows_stride = dst->nb[1] / ggml_type_size(dst->type);
+            ggml_set_op_params_i32(rope, 15, set_rows_stride);
+            ggml_cuda_op_rope(*cuda_ctx, rope);
+            i += 2;
+            continue;
+        }
+
         if (node->op == GGML_OP_MUL) {
             int current_node = i + 1;
             int num_views    = 0;
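Note (illustration, not part of the commit): the fused dispatch above gives the ROPE node everything needed to bypass the separate VIEW and SET_ROWS ops. The I64 row-index tensor becomes rope->src[3], the node writes directly into the SET_ROWS destination buffer, and op_params slot 15 carries the destination row stride converted from bytes to elements. A minimal sketch of that stride conversion, assuming a hypothetical packed F16 cache row of 8 heads x 128 dims:

    // Illustration only: how set_rows_stride relates to the destination's byte stride.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t type_size = 2;                  // ggml_type_size(GGML_TYPE_F16)
        const int64_t ne_row    = 8 * 128;            // elements per cache row (hypothetical)
        const int64_t nb1       = ne_row * type_size; // byte stride between cache rows, i.e. dst->nb[1]

        // Same computation as the fused dispatch: bytes per row / bytes per element.
        const int64_t set_rows_stride = nb1 / type_size;

        // The kernel addresses the cache as a flat element array:
        // row_indices[token] * set_rows_stride is the first element of that token's row.
        // Taking the stride from nb[1] rather than from the logical row length also
        // covers destinations whose rows are padded.
        std::printf("set_rows_stride = %lld elements\n", (long long) set_rows_stride);
        return 0;
    }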

ggml/src/ggml-cuda/rope.cu

Lines changed: 143 additions & 57 deletions
@@ -37,11 +37,23 @@ static __device__ void rope_yarn(
     }
 }
 
-template<bool forward, bool has_ff, typename T>
-static __global__ void rope_norm(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-    const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_norm(const T * x,
+                                 D * dst,
+                                 const int ne0,
+                                 const int ne1,
+                                 const int s1,
+                                 const int s2,
+                                 const int n_dims,
+                                 const int32_t * pos,
+                                 const float freq_scale,
+                                 const float ext_factor,
+                                 const float attn_factor,
+                                 const rope_corr_dims corr_dims,
+                                 const float theta_scale,
+                                 const float * freq_factors,
+                                 const int64_t * row_indices,
+                                 const int set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -53,12 +65,19 @@ static __global__ void rope_norm(
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
-    const int idst = row_dst*ne0 + i0;
+    int idst = row_dst * ne0 + i0;
     const int ix = channel_x*s2 + row_x*s1 + i0;
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+    if (set_rows_stride != 0) {
+        idst = row_x * ne0 + i0;
+        idst += row_indices[channel_x] * set_rows_stride;
+    }
+
     if (i0 >= n_dims) {
-        dst[idst + 0] = x[ix + 0];
-        dst[idst + 1] = x[ix + 1];
+        dst[idst + 0] = D(x[ix + 0]);
+        dst[idst + 1] = D(x[ix + 1]);
 
         return;
     }
@@ -75,15 +94,27 @@ static __global__ void rope_norm(
     const float x0 = x[ix + 0];
     const float x1 = x[ix + 1];
 
-    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
-    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta);
+    dst[idst + 1] = D(x0 * sin_theta + x1 * cos_theta);
 }
 
-template<bool forward, bool has_ff, typename T>
-static __global__ void rope_neox(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-    const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_neox(const T * x,
+                                 D * dst,
+                                 const int ne0,
+                                 const int ne1,
+                                 const int s1,
+                                 const int s2,
+                                 const int n_dims,
+                                 const int32_t * pos,
+                                 const float freq_scale,
+                                 const float ext_factor,
+                                 const float attn_factor,
+                                 const rope_corr_dims corr_dims,
+                                 const float theta_scale,
+                                 const float * freq_factors,
+                                 const int64_t * row_indices,
+                                 const int set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -95,12 +126,19 @@ static __global__ void rope_neox(
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
-    const int idst = row_dst*ne0 + i0/2;
+    int idst = row_dst * ne0 + i0 / 2;
     const int ix = channel_x*s2 + row_x*s1 + i0/2;
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+    if (set_rows_stride != 0) {
+        idst = row_x * ne0 + i0 / 2;
+        idst += row_indices[channel_x] * set_rows_stride;
+    }
+
     if (i0 >= n_dims) {
-        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
-        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+        dst[idst + i0 / 2 + 0] = D(x[ix + i0 / 2 + 0]);
+        dst[idst + i0 / 2 + 1] = D(x[ix + i0 / 2 + 1]);
 
         return;
     }
@@ -117,8 +155,8 @@ static __global__ void rope_neox(
     const float x0 = x[ix + 0];
     const float x1 = x[ix + n_dims/2];
 
-    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
-    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta);
+    dst[idst + n_dims / 2] = D(x0 * sin_theta + x1 * cos_theta);
 }
 
 template<bool forward, bool has_ff, typename T>
@@ -238,11 +276,25 @@ static __global__ void rope_vision(
     dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
 }
 
-template<bool forward, typename T>
-static void rope_norm_cuda(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-    const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+template <bool forward, typename T, typename D>
+static void rope_norm_cuda(const T * x,
+                           D * dst,
+                           const int ne0,
+                           const int ne1,
+                           const int s1,
+                           const int s2,
+                           const int n_dims,
+                           const int nr,
+                           const int32_t * pos,
+                           const float freq_scale,
+                           const float freq_base,
+                           const float ext_factor,
+                           const float attn_factor,
+                           const rope_corr_dims corr_dims,
+                           const float * freq_factors,
+                           const int64_t * row_indices,
+                           const int set_rows_stride,
+                           cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -252,20 +304,34 @@ static void rope_norm_cuda(
 
     if (freq_factors == nullptr) {
         rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     } else {
         rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     }
 }
 
-template<bool forward, typename T>
-static void rope_neox_cuda(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-    const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+template <bool forward, typename T, typename D>
+static void rope_neox_cuda(const T * x,
+                           D * dst,
+                           const int ne0,
+                           const int ne1,
+                           const int s1,
+                           const int s2,
+                           const int n_dims,
+                           const int nr,
+                           const int32_t * pos,
+                           const float freq_scale,
+                           const float freq_base,
+                           const float ext_factor,
+                           const float attn_factor,
+                           const rope_corr_dims corr_dims,
+                           const float * freq_factors,
+                           const int64_t * row_indices,
+                           const int set_rows_stride,
+                           cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -274,13 +340,13 @@ static void rope_neox_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+        rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     } else {
-        rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
-            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-            attn_factor, corr_dims, theta_scale, freq_factors);
+        rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     }
 }
 
@@ -337,6 +403,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
+    const ggml_tensor * src3 = dst->src[3];
 
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
@@ -346,7 +413,9 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(src0->type == dst->type);
+    // When not fused, src0 and dst types must match
+    // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
+    GGML_ASSERT(src0->type == dst->type || dst->type == GGML_TYPE_F16);
 
     const int64_t ne00 = src0->ne[0]; // head dims
     const int64_t ne01 = src0->ne[1]; // num heads
@@ -399,19 +468,32 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
         freq_factors = (const float *) src2->data;
     }
 
+    // Row indices for fused ROPE + VIEW + SET_ROWS
+    const int64_t * row_indices = nullptr;
+    int set_rows_stride = 0;
+    if (src3 != nullptr) {
+        GGML_ASSERT(src3->type == GGML_TYPE_I64);
+        row_indices     = (const int64_t *) src3->data;
+        set_rows_stride = ggml_get_op_params_i32(dst, 15);
+    }
+
     rope_corr_dims corr_dims;
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
 
     // compute
     if (is_neox) {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda<forward>(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda<forward>(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                  freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+            rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                 freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
+                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                freq_factors, row_indices, set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
@@ -440,14 +522,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             GGML_ABORT("fatal error");
         }
     } else {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_norm_cuda<forward>(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_norm_cuda<forward>(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                  freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+            rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                 freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+            rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
+                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                                freq_factors, row_indices, set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
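Note (illustration, not part of the commit): when set_rows_stride != 0, both kernels replace the usual destination index with row_x * ne0 + i0 (i0 / 2 for neox) plus row_indices[channel_x] * set_rows_stride. That is exactly the element offset the unfused chain would produce: the VIEW flattens the (ne0, ne1) axes into one row of ne0*ne1 elements per token, and SET_ROWS scatters that row to slot row_indices[token] of the destination. A host-side sketch with hypothetical sizes that checks the equivalence:

    // Illustration only: compare the unfused ROPE -> VIEW -> SET_ROWS element offset
    // with the idst computed inside rope_norm when set_rows_stride != 0.
    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t ne0             = 128;              // head dim (rope->ne[0], hypothetical)
        const int64_t ne1             = 8;                // number of heads (rope->ne[1], hypothetical)
        const int64_t row_indices[]   = { 42, 7, 1000 };  // hypothetical KV-cache slots, one per token
        const int64_t n_tokens        = 3;
        const int64_t set_rows_stride = ne0 * ne1;        // dst->nb[1] / ggml_type_size(dst->type), packed rows

        for (int64_t channel_x = 0; channel_x < n_tokens; ++channel_x) {   // token
            for (int64_t row_x = 0; row_x < ne1; ++row_x) {                // head
                for (int64_t i0 = 0; i0 < ne0; ++i0) {                     // element within the head
                    // Unfused: offset in the temporary rope output ...
                    const int64_t rope_offset = (channel_x*ne1 + row_x)*ne0 + i0;
                    // ... which the VIEW reinterprets as one row of ne0*ne1 elements per token ...
                    const int64_t view_row = rope_offset / (ne0*ne1);      // == channel_x
                    const int64_t view_col = rope_offset % (ne0*ne1);      // == row_x*ne0 + i0
                    // ... and SET_ROWS scatters to row row_indices[view_row] of the destination.
                    const int64_t unfused = row_indices[view_row]*set_rows_stride + view_col;

                    // Fused: the index the kernel computes directly.
                    const int64_t fused = row_x*ne0 + i0 + row_indices[channel_x]*set_rows_stride;

                    assert(unfused == fused);
                }
            }
        }
        return 0;
    }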

src/llama-graph.cpp

Lines changed: 2 additions & 1 deletion
@@ -1592,14 +1592,15 @@ ggml_tensor * llm_graph_context::build_attn(
         int il) const {
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
+    // expand k later to enable rope fusion which directly writes into k-v cache
     ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
     const auto * mctx_cur = inp->mctx;
 
     // store to KV cache
     {
+        ggml_build_forward_expand(gf, k_cur);
         const auto & k_idxs = inp->get_k_idxs();
         const auto & v_idxs = inp->get_v_idxs();
 
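Note (illustration, not part of the commit): this reordering matters because the CUDA fusion check only inspects consecutive graph nodes (node_idx, node_idx + 1, node_idx + 2). Expanding k_cur only when the KV-cache store is built keeps the ROPE that produces it adjacent to the VIEW and SET_ROWS that write it into the cache. A minimal sketch of the adjacency requirement, with hypothetical op sequences standing in for real graph contents:

    // Illustration only: the fusion fires only when ROPE, VIEW, SET_ROWS are consecutive nodes.
    #include <cstddef>
    #include <vector>

    enum Op { ROPE, VIEW, SET_ROWS, MUL_MAT, OTHER };

    static bool rope_set_rows_adjacent(const std::vector<Op> & nodes, std::size_t i) {
        return i + 2 < nodes.size() &&
               nodes[i] == ROPE && nodes[i + 1] == VIEW && nodes[i + 2] == SET_ROWS;
    }

    int main() {
        // Expanding k_cur early interleaves unrelated nodes between ROPE and SET_ROWS:
        const std::vector<Op> before = { ROPE, MUL_MAT, OTHER, VIEW, SET_ROWS };
        // Expanding k_cur just before the KV-cache store keeps the triple adjacent:
        const std::vector<Op> after  = { MUL_MAT, OTHER, ROPE, VIEW, SET_ROWS };

        bool fused_before = false;
        bool fused_after  = false;
        for (std::size_t i = 0; i < before.size(); ++i) { fused_before |= rope_set_rows_adjacent(before, i); }
        for (std::size_t i = 0; i < after.size();  ++i) { fused_after  |= rope_set_rows_adjacent(after,  i); }

        return (!fused_before && fused_after) ? 0 : 1;
    }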