Skip to content

Commit a90eb94

Browse files
authored
CUDA: fuse rope + set_rows (#16884)
* CUDA: add fused rope
* move k forward_expand up
* create helper function instead of re-using params
* make assert statement more in line with comment
* rope_norm: coalesced writes to global mem
1 parent 07751f8 commit a90eb94

File tree

4 files changed

+215
-61
lines changed

4 files changed

+215
-61
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2992,6 +2992,36 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
29922992
}
29932993
#endif
29942994

2995+
// Returns true when a ROPE -> VIEW -> SET_ROWS node chain can be handled by
// the fused rope+set_rows CUDA path. All checks are pure reads of tensor
// metadata; no tensor data is touched.
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
                                                const ggml_tensor * view,
                                                const ggml_tensor * set_rows) {
    // rope mode is stored at index 2 of op_params (int32 array).
    const int32_t mode = ((const int32_t *) rope->op_params)[2];

    // Fusion with ne3 != 1 has not been tested.
    const bool batch_ok = rope->src[0]->ne[3] == 1;

    // The fused kernel only writes F32/F16 destinations.
    const bool dst_type_ok = set_rows->type == GGML_TYPE_F32 || set_rows->type == GGML_TYPE_F16;

    // Row indices must be 64-bit integers.
    const bool idx_type_ok = set_rows->src[1]->type == GGML_TYPE_I64;

    // The intermediate view must flatten rope's first two dims into one.
    const bool view_ok = ggml_is_contiguous(view) && view->ne[0] == rope->ne[0] * rope->ne[1];

    // Only the norm/neox rope shaders implement the fusion code path.
    const bool mode_ok = mode == GGML_ROPE_TYPE_NORMAL || mode == GGML_ROPE_TYPE_NEOX;

    return batch_ok && dst_type_ok && idx_type_ok && view_ok && mode_ok;
}
3024+
29953025
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
29963026
#ifndef NDEBUG
29973027
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
@@ -3067,6 +3097,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30673097
}
30683098
}
30693099

3100+
if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
3101+
const ggml_tensor * rope = cgraph->nodes[node_idx];
3102+
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
3103+
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
3104+
3105+
if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
3106+
return true;
3107+
}
3108+
}
3109+
30703110
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
30713111
return false;
30723112
}
@@ -3196,6 +3236,15 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
31963236
continue;
31973237
}
31983238

3239+
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3240+
ggml_tensor * rope = cgraph->nodes[i];
3241+
ggml_tensor * set_rows = cgraph->nodes[i + 2];
3242+
3243+
ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
3244+
i += 2;
3245+
continue;
3246+
}
3247+
31993248
if (node->op == GGML_OP_ADD) {
32003249
int n_fuse = 0;
32013250
ggml_op ops[8];

0 commit comments

Comments
 (0)