@@ -1955,6 +1955,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19551955static  int  ggml_metal_encode_node (
19561956                        ggml_backend_t    backend,
19571957                                   int    idx,
1958+                                    int    idx_end,
19581959          id <MTLComputeCommandEncoder >   encoder,
19591960            struct  ggml_metal_mem_pool * mem_pool) {
19601961    struct  ggml_backend_metal_context        * ctx     = backend->context ;
@@ -2181,7 +2182,9 @@ static int ggml_metal_encode_node(
21812182                    size_t  offs_fuse;
21822183                    id <MTLBuffer > id_fuse;
21832184
2184-                     for  (n_fuse = 0 ; n_fuse <= 6 ; ++n_fuse) {
2185+                     //  note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes
2186+                     //        across splits. idx_end indicates the last node in the current split
2187+                     for  (n_fuse = 0 ; n_fuse <= 6  && idx + n_fuse + 1  < idx_end; ++n_fuse) {
21852188                        if  (!ggml_can_fuse (gf, idx + n_fuse, ops + n_fuse, 2 )) {
21862189                            break ;
21872190                        }
@@ -4288,7 +4291,7 @@ static int ggml_metal_encode_node(
42884291                    ops[1 ] = GGML_OP_MUL;
42894292                    ops[2 ] = GGML_OP_ADD;
42904293
4291-                     for  (n_fuse = 0 ; n_fuse <= 1 ; ++n_fuse) {
4294+                     for  (n_fuse = 0 ; n_fuse <= 1  && idx + n_fuse +  1  < idx_end ; ++n_fuse) {
42924295                        if  (!ggml_can_fuse (gf, idx + n_fuse, ops + n_fuse, 2 )) {
42934296                            break ;
42944297                        }
@@ -6271,7 +6274,11 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
62716274                [encoder pushDebugGroup: [NSString  stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
62726275            }
62736276
6274-             const  int  res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
6277+             const  int  res = ggml_metal_encode_node (backend, idx, node_end, encoder, mem_pool);
6278+             if  (idx + res > node_end) {
6279+                 GGML_ABORT (" fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s " 
6280+                         " https://github.com/ggml-org/llama.cpp/pull/14849" 
6281+             }
62756282
62766283            if  (should_capture) {
62776284                [encoder popDebugGroup ];
0 commit comments