@@ -2992,6 +2992,36 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
29922992}
29932993#endif
29942994
// Decide whether a ROPE -> VIEW -> SET_ROWS chain is eligible for the fused
// rope+set_rows kernel. Pure predicate: inspects tensor metadata only, no
// side effects. Returns true only when every constraint below holds.
static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
                                                const ggml_tensor * view,
                                                const ggml_tensor * set_rows) {
    // ne3 > 1 is untested for this fusion path.
    const bool batch_ok = rope->src[0]->ne[3] == 1;

    // The fused kernel only writes F32/F16 destinations.
    const bool dst_type_ok = set_rows->type == GGML_TYPE_F32 ||
                             set_rows->type == GGML_TYPE_F16;

    // Row indices must be 64-bit integers.
    const bool idx_type_ok = set_rows->src[1]->type == GGML_TYPE_I64;

    // The intermediate view must flatten rope's first two dims into a single
    // contiguous row.
    const bool view_ok = ggml_is_contiguous(view) &&
                         view->ne[0] == rope->ne[0] * rope->ne[1];

    // Only the norm/neox rope variants have fusion support in the kernels.
    const int  mode    = ((const int32_t *) rope->op_params)[2];
    const bool mode_ok = mode == GGML_ROPE_TYPE_NORMAL ||
                         mode == GGML_ROPE_TYPE_NEOX;

    return batch_ok && dst_type_ok && idx_type_ok && view_ok && mode_ok;
}
3024+
29953025static bool ggml_cuda_can_fuse (const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
29963026#ifndef NDEBUG
29973027 const size_t num_unary = std::count (ops.begin (), ops.end (), GGML_OP_UNARY);
@@ -3067,6 +3097,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
30673097 }
30683098 }
30693099
3100+ if (ops.size () == 3 && ggml_can_fuse_subgraph (cgraph, node_idx, ops, { node_idx + 2 })) {
3101+ const ggml_tensor * rope = cgraph->nodes [node_idx];
3102+ const ggml_tensor * view = cgraph->nodes [node_idx + 1 ];
3103+ const ggml_tensor * set_rows = cgraph->nodes [node_idx + 2 ];
3104+
3105+ if (ggml_cuda_should_fuse_rope_set_rows (rope, view, set_rows)) {
3106+ return true ;
3107+ }
3108+ }
3109+
30703110 if (!ggml_can_fuse (cgraph, node_idx, ops)) {
30713111 return false ;
30723112 }
@@ -3196,6 +3236,15 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
31963236 continue ;
31973237 }
31983238
3239+ if (ggml_cuda_can_fuse (cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3240+ ggml_tensor * rope = cgraph->nodes [i];
3241+ ggml_tensor * set_rows = cgraph->nodes [i + 2 ];
3242+
3243+ ggml_cuda_op_rope_fused (*cuda_ctx, rope, set_rows);
3244+ i += 2 ;
3245+ continue ;
3246+ }
3247+
31993248 if (node->op == GGML_OP_ADD) {
32003249 int n_fuse = 0 ;
32013250 ggml_op ops[8 ];
0 commit comments