236236    bool  should_capture_next_compute;
237237    bool  capture_started;
238238
239+     id <MTLCaptureScope > cap_scope;
240+ 
239241    //  abort ggml_metal_graph_compute if callback returns true
240242    ggml_abort_callback abort_callback;
241243    void  *              abort_callback_data;
@@ -459,6 +461,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
459461    ctx->should_capture_next_compute  = false ;
460462    ctx->capture_started  = false ;
461463
464+     ctx->cap_scope  = nil ;
465+ 
462466#if  TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
463467    if  (@available (macOS 10.12 , iOS 16.0 , *)) {
464468        GGML_METAL_LOG_INFO (" %s : recommendedMaxWorkingSetSize  = %8.2f  MB\n "  , __func__, ctx->device .recommendedMaxWorkingSetSize  / 1e6 );
@@ -887,17 +891,21 @@ static enum ggml_status ggml_metal_graph_compute(
887891    //  create multiple command buffers and enqueue them
888892    //  then, we encode the graph into the command buffers in parallel
889893
890-     const  int  n_nodes = gf->n_nodes ;
894+     const  int  n_nodes_0 = MIN (64 , gf->n_nodes );
895+     const  int  n_nodes_1 = gf->n_nodes  - n_nodes_0;
891896    const  int  n_cb = ctx->n_cb ;
892-     const  int  n_nodes_per_cb = (n_nodes  + n_cb - 1 ) / n_cb;
897+     const  int  n_nodes_per_cb = (n_nodes_1  + n_cb - 1 ) / n_cb;
893898
894899    const  bool  should_capture = ctx->should_capture_next_compute ;
895900    if  (should_capture) {
896901        ctx->should_capture_next_compute  = false ;
897902
898903        if  (!ctx->capture_started ) {
904+             //  create capture scope
905+             ctx->cap_scope  = [[MTLCaptureManager  sharedCaptureManager ] newCaptureScopeWithDevice: ctx->device];
906+ 
899907            MTLCaptureDescriptor  * descriptor = [MTLCaptureDescriptor  new ];
900-             descriptor.captureObject  = ctx->queue ;
908+             descriptor.captureObject  = ctx->cap_scope ;
901909            descriptor.destination  = MTLCaptureDestinationGPUTraceDocument ;
902910            descriptor.outputURL  = [NSURL  fileURLWithPath: [NSString  stringWithFormat: @" /tmp/perf-metal.gputrace"  ]];
903911
@@ -906,26 +914,17 @@ static enum ggml_status ggml_metal_graph_compute(
906914                GGML_METAL_LOG_ERROR (" %s : error: unable to start capture '%s '\n "  , __func__, [[error localizedDescription ] UTF8String ]);
907915                GGML_ABORT (" capture failed"  );
908916            } else  {
917+                 [ctx->cap_scope beginScope ];
909918                ctx->capture_started  = true ;
910919            }
911920        }
912921    }
913922
914-     id <MTLCommandBuffer > command_buffer_builder[n_cb];
915-     for  (int  cb_idx = 0 ; cb_idx < n_cb; ++cb_idx) {
916-         id <MTLCommandBuffer > command_buffer  = [ctx->queue commandBufferWithUnretainedReferences ];
917-         command_buffer_builder[cb_idx] = command_buffer;
918- 
919-         //  always enqueue the first two command buffers
920-         //  enqueue all of the command buffers if we don't need to abort
921-         if  (cb_idx < 2  || ctx->abort_callback  == NULL ) {
922-             [command_buffer enqueue ];
923-         }
924-     }
925- 
926-     const  id <MTLCommandBuffer > *command_buffers = command_buffer_builder;
923+     id <MTLCommandBuffer > command_buffer_builder[n_cb + 1 ];
924+     const  id <MTLCommandBuffer > * command_buffers = command_buffer_builder;
927925
928-     dispatch_apply (n_cb, ctx->d_queue , ^(size_t  iter) {
926+     // dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
927+     void  (^helper)(size_t  iter) = ^(size_t  iter) {
929928        const  int  cb_idx = iter;
930929
931930        size_t  offs_src0 = 0 ;
@@ -936,8 +935,12 @@ static enum ggml_status ggml_metal_graph_compute(
936935        id <MTLCommandBuffer > command_buffer  = command_buffers[cb_idx];
937936        id <MTLComputeCommandEncoder > encoder = [command_buffer computeCommandEncoderWithDescriptor:  edesc];
938937
939-         const  int  node_start =                                      (cb_idx + 0 ) * n_nodes_per_cb;
940-         const  int  node_end   = MIN ((cb_idx == n_cb - 1 ) ? n_nodes : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes);
938+         int  node_start = 0 ;
939+         int  node_end   = n_nodes_0;
940+         if  ((int ) iter < n_cb) {
941+             node_start = n_nodes_0 + (                                       (cb_idx + 0 ) * n_nodes_per_cb);
942+             node_end   = n_nodes_0 + (MIN ((cb_idx == n_cb - 1 ) ? n_nodes_1 : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes_1));
943+         }
941944
942945        for  (int  i = node_start; i < node_end; ++i) {
943946            if  (i == -1 ) {
@@ -3037,11 +3040,36 @@ static enum ggml_status ggml_metal_graph_compute(
30373040        if  (cb_idx < 2  || ctx->abort_callback  == NULL ) {
30383041            [command_buffer commit ];
30393042        }
3040-     });
3043+     };
3044+ 
3045+     {
3046+         id <MTLCommandBuffer > command_buffer = [ctx->queue commandBufferWithUnretainedReferences ];
3047+         command_buffer_builder[n_cb] = command_buffer;
3048+         [command_buffer enqueue ];
3049+         helper (n_cb);
3050+     }
3051+ 
3052+     for  (int  cb_idx = 0 ; cb_idx < n_cb; ++cb_idx) {
3053+         id <MTLCommandBuffer > command_buffer  = [ctx->queue commandBufferWithUnretainedReferences ];
3054+         command_buffer_builder[cb_idx] = command_buffer;
3055+ 
3056+         //  always enqueue the first two command buffers
3057+         //  enqueue all of the command buffers if we don't need to abort
3058+         if  (cb_idx < 2  || ctx->abort_callback  == NULL ) {
3059+             [command_buffer enqueue ];
3060+         }
3061+     }
3062+ 
3063+     dispatch_apply (n_cb, ctx->d_queue , helper);
30413064
30423065    //  Wait for completion and check status of each command buffer
30433066    //  needed to detect if the device ran out-of-memory for example (#1881)
30443067
3068+     {
3069+         id <MTLCommandBuffer > command_buffer = command_buffers[n_cb];
3070+         [command_buffer waitUntilCompleted ];
3071+     }
3072+ 
30453073    for  (int  i = 0 ; i < n_cb; ++i) {
30463074        id <MTLCommandBuffer > command_buffer = command_buffers[i];
30473075        [command_buffer waitUntilCompleted ];
@@ -3075,6 +3103,7 @@ static enum ggml_status ggml_metal_graph_compute(
30753103    }
30763104
30773105    if  (!should_capture && ctx->capture_started ) {
3106+         [ctx->cap_scope endScope ];
30783107        [[MTLCaptureManager  sharedCaptureManager ] stopCapture ];
30793108    }
30803109
0 commit comments