@@ -15,7 +15,8 @@
 // max memory buffers that can be mapped to the device
 #define GGML_METAL_MAX_BUFFERS 64
 
-#define GGML_METAL_MAX_COMMAND_BUFFERS 128
+// max number of MTLCommandBuffer used to submit a graph for processing
+#define GGML_METAL_MAX_COMMAND_BUFFERS 8
 
 #ifdef GGML_METAL_NDEBUG
 #define GGML_METAL_LOG(...)
@@ -226,8 +227,6 @@
 };
 
 struct ggml_backend_metal_context {
-    int n_cb;
-
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
 
@@ -240,21 +239,28 @@
     bool support_simdgroup_reduction;
     bool support_simdgroup_mm;
 
-    bool should_capture_next_compute;
+    // capture state
+    bool capture_next_compute;
     bool capture_started;
 
-    id<MTLCaptureScope> cap_scope;
-
-    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    id<MTLCaptureScope> capture_scope;
 
-    int n_nodes_0;
-    int n_nodes_1;
+    // command buffer state
+    int n_cb;           // number of extra threads used to submit the command buffers
+    int n_nodes_0;      // number of nodes submitted by the main thread
+    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
     int n_nodes_per_cb;
 
     struct ggml_cgraph * gf;
 
+    // the callback given to the thread pool
+    // TODO: ideally, this should be created once, utilizing the command buffer state above
+    //       for some reason, doing it like this leads to a crash
     void (^encode_async)(size_t ith);
 
+    // n_cb command buffers + 1 used by the main thread
+    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
     void *              abort_callback_data;
@@ -476,17 +482,16 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 
-    ctx->should_capture_next_compute = false;
+    ctx->capture_next_compute = false;
     ctx->capture_started = false;
+    ctx->capture_scope = nil;
 
-    ctx->cap_scope = nil;
-
+    ctx->gf = nil;
+    ctx->encode_async = nil;
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
         ctx->command_buffers[i] = nil;
     }
 
-    ctx->encode_async = nil;
-
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
         GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
@@ -3000,31 +3005,37 @@ static void ggml_metal_encode_node(
 static enum ggml_status ggml_metal_graph_compute(
         struct ggml_backend_metal_context * ctx,
                        struct ggml_cgraph * gf) {
-    @autoreleasepool {
-        // create multiple command buffers and enqueue them
-        // then, we encode the graph into the command buffers in parallel
+    // number of nodes encoded by the main thread (empirically determined)
+    const int n_main = 128;
 
-        const int n_cb = ctx->n_cb;
+    // number of threads in addition to the main thread
+    const int n_cb = ctx->n_cb;
 
+    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
+    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
+    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
+    // each thread creates its own command buffer and enqueues the ops in parallel
+    //
+    // tests on M1 Pro and M2 Ultra using LLaMA models show that the optimal values for n_cb are 1 or 2
+
+    @autoreleasepool {
         ctx->gf = gf;
 
-        ctx->n_nodes_0 = MIN(128, gf->n_nodes);
+        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
         ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
 
-        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + n_cb - 1) / n_cb;
+        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
 
-        //const int64_t t_start = ggml_time_us();
-
-        const bool should_capture = ctx->should_capture_next_compute;
+        const bool should_capture = ctx->capture_next_compute;
         if (should_capture) {
-            ctx->should_capture_next_compute = false;
+            ctx->capture_next_compute = false;
 
             if (!ctx->capture_started) {
                 // create capture scope
-                ctx->cap_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
+                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
 
                 MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
-                descriptor.captureObject = ctx->cap_scope;
+                descriptor.captureObject = ctx->capture_scope;
                 descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
                 descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
 
@@ -3033,7 +3044,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
                     GGML_ABORT("capture failed");
                 } else {
-                    [ctx->cap_scope beginScope];
+                    [ctx->capture_scope beginScope];
                     ctx->capture_started = true;
                 }
             }
@@ -3055,7 +3066,7 @@ static enum ggml_status ggml_metal_graph_compute(
             int node_start = 0;
             int node_end   = n_nodes_0;
 
-            if ((int) iter < n_cb_l) {
+            if (cb_idx < n_cb_l) {
                 node_start = n_nodes_0 + (                                         (cb_idx + 0) * n_nodes_per_cb);
                 node_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
             }
@@ -3079,17 +3090,20 @@ static enum ggml_status ggml_metal_graph_compute(
             }
         };
 
+        // the main thread commits the first few commands immediately
+        // command_buffer[n_cb]
         {
             id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
             ctx->command_buffers[n_cb] = command_buffer;
 
             [command_buffer enqueue];
-
             ctx->encode_async(n_cb);
         }
 
+        // prepare the rest of the command buffers asynchronously
+        // command_buffer[0..n_cb)
         for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-            id<MTLCommandBuffer> command_buffer   = [ctx->queue commandBufferWithUnretainedReferences];
+            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
             ctx->command_buffers[cb_idx] = command_buffer;
 
             // always enqueue the first two command buffers
@@ -3101,14 +3115,8 @@ static enum ggml_status ggml_metal_graph_compute(
 
         dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
 
-        // {
-        //     const int64_t t_end = ggml_time_us();
-        //     //printf("time to encode: %d us, n_cb = %d\n", (int) (t_end - t_start), n_cb);
-        // }
-
-        // Wait for completion and check status of each command buffer
+        // wait for completion and check status of each command buffer
         // needed to detect if the device ran out-of-memory for example (#1881)
-
         {
             id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
             [command_buffer waitUntilCompleted];
@@ -3143,7 +3151,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 continue;
             }
 
-            bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+            const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
             if (next_queued) {
                 continue;
             }
@@ -3156,13 +3164,8 @@ static enum ggml_status ggml_metal_graph_compute(
             [next_buffer commit];
         }
 
-        // {
-        //     const int64_t t_end = ggml_time_us();
-        //     printf("time to compute: %d us\n", (int)(t_end - t_start));
-        // }
-
         if (!should_capture && ctx->capture_started) {
-            [ctx->cap_scope endScope];
+            [ctx->capture_scope endScope];
             [[MTLCaptureManager sharedCaptureManager] stopCapture];
         }
     }
@@ -3514,7 +3517,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         }
     }
 
-    // TODO: setting encode_async here causes crash. why?
+    // TODO: setting encode_async here causes a crash during the next ggml_metal_graph_compute call. why?
     //ctx->encode_async = ^(size_t iter) {
     //    ...
     //};
@@ -3598,7 +3601,7 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
     GGML_ASSERT(ggml_backend_is_metal(backend));
 
     struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
-    ctx->should_capture_next_compute = true;
+    ctx->capture_next_compute = true;
 }
 
 GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
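
As an aside (not part of the patch), the following is a minimal standalone C sketch of the node-partitioning arithmetic introduced above, using hypothetical values for n_nodes, n_main, and n_cb. It only prints the [node_start, node_end) range that each command buffer would encode, mirroring the MIN and ceiling-division logic from ggml_metal_graph_compute and the encode_async block.

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int n_nodes = 300; // total graph nodes (hypothetical)
    const int n_main  = 128; // nodes encoded and committed by the main thread
    const int n_cb    = 2;   // extra threads (1 or 2 reported as optimal above)

    const int n_nodes_0      = MIN(n_main, n_nodes);          // main thread's share
    const int n_nodes_1      = n_nodes - n_nodes_0;           // remainder for the n_cb threads
    const int n_nodes_per_cb = (n_nodes_1 + n_cb - 1) / n_cb; // ceil(n_nodes_1 / n_cb)

    // command_buffers[0..n_cb) belong to the worker threads; command_buffers[n_cb] is the main thread's
    for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) {
        int node_start = 0;
        int node_end   = n_nodes_0;

        if (cb_idx < n_cb) {
            node_start = n_nodes_0 +     ((cb_idx + 0) * n_nodes_per_cb);
            node_end   = n_nodes_0 + MIN((cb_idx == n_cb - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1);
        }

        printf("command_buffer[%d] encodes nodes [%d, %d)\n", cb_idx, node_start, node_end);
    }

    return 0;
}

With these example values it prints [128, 214) and [214, 300) for the two worker command buffers and [0, 128) for the main thread's buffer, so the whole graph is covered exactly once.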