@@ -438,6 +438,7 @@ @implementation GGMLMetalClass
438438    ctx->capture_scope  = nil ;
439439
440440    ctx->gf  = nil ;
441+     Block_release (ctx->encode_async );
441442    ctx->encode_async  = nil ;
442443    for  (int  i = 0 ; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
443444        ctx->command_buffers [i] = nil ;
@@ -3000,46 +3001,6 @@ static enum ggml_status ggml_metal_graph_compute(
30003001            }
30013002        }
30023003
3003-         //  TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
3004-         ctx->encode_async  = ^(size_t  iter) {
3005-             const  int  cb_idx = iter;
3006-             const  int  n_cb_l = ctx->n_cb ;
3007- 
3008-             const  int  n_nodes_0 = ctx->n_nodes_0 ;
3009-             const  int  n_nodes_1 = ctx->n_nodes_1 ;
3010- 
3011-             const  int  n_nodes_per_cb = ctx->n_nodes_per_cb ;
3012- 
3013-             id <MTLCommandBuffer > command_buffer  = ctx->command_buffers [cb_idx];
3014-             id <MTLComputeCommandEncoder > encoder = [command_buffer computeCommandEncoder ];
3015- 
3016-             int  node_start = 0 ;
3017-             int  node_end   = n_nodes_0;
3018- 
3019-             if  (cb_idx < n_cb_l) {
3020-                 node_start = n_nodes_0 + (                                         (cb_idx + 0 ) * n_nodes_per_cb);
3021-                 node_end   = n_nodes_0 + (MIN ((cb_idx == n_cb_l - 1 ) ? n_nodes_1 : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes_1));
3022-             }
3023- 
3024-             for  (int  idx = node_start; idx < node_end; ++idx) {
3025-                 if  (should_capture) {
3026-                     [encoder pushDebugGroup: [NSString  stringWithCString: ggml_op_desc (ggml_graph_node (gf, idx)) encoding: NSUTF8StringEncoding]];
3027-                 }
3028- 
3029-                 ggml_metal_encode_node (ctx, idx, encoder);
3030- 
3031-                 if  (should_capture) {
3032-                     [encoder popDebugGroup ];
3033-                 }
3034-             }
3035- 
3036-             [encoder endEncoding ];
3037- 
3038-             if  (cb_idx < 2  || ctx->abort_callback  == NULL ) {
3039-                 [command_buffer commit ];
3040-             }
3041-         };
3042- 
30433004        //  the main thread commits the first few commands immediately
30443005        //  command_buffer[n_cb]
30453006        {
@@ -3468,10 +3429,46 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
34683429        }
34693430    }
34703431
3471-     //  TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
3472-     // ctx->encode_async = ^(size_t iter) {
3473-     //     ...
3474-     // };
3432+     ctx->encode_async  = Block_copy (^(size_t  iter) {
3433+         const  int  cb_idx = iter;
3434+         const  int  n_cb_l = ctx->n_cb ;
3435+ 
3436+         const  int  n_nodes_0 = ctx->n_nodes_0 ;
3437+         const  int  n_nodes_1 = ctx->n_nodes_1 ;
3438+ 
3439+         const  int  n_nodes_per_cb = ctx->n_nodes_per_cb ;
3440+ 
3441+         id <MTLCommandBuffer > command_buffer  = ctx->command_buffers [cb_idx];
3442+         id <MTLComputeCommandEncoder > encoder = [command_buffer computeCommandEncoder ];
3443+ 
3444+         int  node_start = 0 ;
3445+         int  node_end   = n_nodes_0;
3446+ 
3447+         if  (cb_idx < n_cb_l) {
3448+             node_start = n_nodes_0 + (                                         (cb_idx + 0 ) * n_nodes_per_cb);
3449+             node_end   = n_nodes_0 + (MIN ((cb_idx == n_cb_l - 1 ) ? n_nodes_1 : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes_1));
3450+         }
3451+ 
3452+         const  bool  should_capture = ctx->capture_next_compute ;
3453+ 
3454+         for  (int  idx = node_start; idx < node_end; ++idx) {
3455+             if  (should_capture) {
3456+                 [encoder pushDebugGroup: [NSString  stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
3457+             }
3458+ 
3459+             ggml_metal_encode_node (ctx, idx, encoder);
3460+ 
3461+             if  (should_capture) {
3462+                 [encoder popDebugGroup ];
3463+             }
3464+         }
3465+ 
3466+         [encoder endEncoding ];
3467+ 
3468+         if  (cb_idx < 2  || ctx->abort_callback  == NULL ) {
3469+             [command_buffer commit ];
3470+         }
3471+     });
34753472}
34763473
34773474static  struct  ggml_backend_i ggml_backend_metal_i = {
0 commit comments