@@ -591,13 +591,16 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
591591    return  buf;
592592}
593593
594+ struct  ggml_metal_command_buffer {
595+     id <MTLCommandBuffer > obj;
596+ 
597+     struct  ggml_metal_heap * heap;
598+ };
599+ 
594600struct  ggml_backend_metal_context {
595601    id <MTLDevice >       device;
596602    id <MTLCommandQueue > queue;
597603
598-     //  TODO: create heap per command buffer
599-     struct  ggml_metal_heap * heap;
600- 
601604    dispatch_queue_t  d_queue;
602605
603606    struct  ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
@@ -620,7 +623,8 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
620623    void  (^encode_async)(size_t  ith);
621624
622625    //  n_cb command buffers + 1 used by the main thread
623-     id <MTLCommandBuffer > command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1 ];
626+     // id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
627+     struct  ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1 ];
624628
625629    //  abort ggml_metal_graph_compute if callback returns true
626630    ggml_abort_callback abort_callback;
@@ -822,8 +826,6 @@ @implementation GGMLMetalClass
822826
823827    ctx->d_queue  = dispatch_queue_create (" ggml-metal"  , DISPATCH_QUEUE_CONCURRENT);
824828
825-     ctx->heap  = ggml_metal_heap_init (device, 1024 *1024 );
826- 
827829    //  load library
828830    if  (ctx_dev->mtl_library  == nil ) {
829831        ctx_dev->mtl_library  = ggml_metal_load_library (device, ctx_dev->use_bfloat );
@@ -877,7 +879,11 @@ @implementation GGMLMetalClass
877879    ctx->gf  = nil ;
878880    ctx->encode_async  = nil ;
879881    for  (int  i = 0 ; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
880-         ctx->command_buffers [i] = nil ;
882+         ctx->cmd_bufs [i].obj  = nil ;
883+ 
884+         //  create 1MB heaps per command buffer
885+         //  these can be resized during compute when necessary
886+         ctx->cmd_bufs [i].heap  = ggml_metal_heap_init (device, 1024 *1024 );
881887    }
882888
883889#if  TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1268,7 +1274,11 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
12681274
12691275    [ctx->queue release ];
12701276
1271-     ggml_metal_heap_free (ctx->heap );
1277+     // ggml_metal_heap_free(ctx->heap);
1278+     for  (int  i = 0 ; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
1279+         [ctx->cmd_bufs[i].obj release ];
1280+         ggml_metal_heap_free (ctx->cmd_bufs [i].heap );
1281+     }
12721282
12731283    dispatch_release (ctx->d_queue );
12741284
@@ -4711,25 +4721,25 @@ static enum ggml_status ggml_metal_graph_compute(
47114721        }
47124722
47134723        //  the main thread commits the first few commands immediately
4714-         //  command_buffer [n_cb]
4724+         //  cmd_buf [n_cb]
47154725        {
4716-             id <MTLCommandBuffer > command_buffer  = [ctx->queue commandBufferWithUnretainedReferences ];
4717-             ctx->command_buffers [n_cb] = command_buffer ;
4726+             id <MTLCommandBuffer > cmd_buf  = [ctx->queue commandBufferWithUnretainedReferences ];
4727+             ctx->cmd_bufs [n_cb]. obj  = cmd_buf ;
47184728
4719-             [command_buffer  enqueue ];
4729+             [cmd_buf  enqueue ];
47204730            ctx->encode_async (n_cb);
47214731        }
47224732
47234733        //  prepare the rest of the command buffers asynchronously
4724-         //  command_buffer [0.. n_cb)
4734+         //  cmd_buf [0.. n_cb)
47254735        for  (int  cb_idx = 0 ; cb_idx < n_cb; ++cb_idx) {
4726-             id <MTLCommandBuffer > command_buffer  = [ctx->queue commandBufferWithUnretainedReferences ];
4727-             ctx->command_buffers [cb_idx] = command_buffer ;
4736+             id <MTLCommandBuffer > cmd_buf  = [ctx->queue commandBufferWithUnretainedReferences ];
4737+             ctx->cmd_bufs [cb_idx]. obj  = cmd_buf ;
47284738
47294739            //  always enqueue the first two command buffers
47304740            //  enqueue all of the command buffers if we don't need to abort
47314741            if  (cb_idx < 2  || ctx->abort_callback  == NULL ) {
4732-                 [command_buffer  enqueue ];
4742+                 [cmd_buf  enqueue ];
47334743            }
47344744        }
47354745
@@ -4738,40 +4748,39 @@ static enum ggml_status ggml_metal_graph_compute(
47384748        //  wait for completion and check status of each command buffer
47394749        //  needed to detect if the device ran out-of-memory for example (#1881)
47404750        {
4741-             id <MTLCommandBuffer > command_buffer  = ctx->command_buffers [n_cb];
4742-             [command_buffer  waitUntilCompleted ];
4751+             id <MTLCommandBuffer > cmd_buf  = ctx->cmd_bufs [n_cb]. obj ;
4752+             [cmd_buf  waitUntilCompleted ];
47434753
4744-             //  TODO: free main cb  heap
4754+             ggml_metal_heap_reset (ctx-> cmd_bufs [n_cb]. heap ); 
47454755
4746-             MTLCommandBufferStatus  status = [command_buffer  status ];
4756+             MTLCommandBufferStatus  status = [cmd_buf  status ];
47474757            if  (status != MTLCommandBufferStatusCompleted ) {
47484758                GGML_LOG_INFO (" %s : command buffer %d  failed with status %lu \n "  , __func__, n_cb, status);
47494759                if  (status == MTLCommandBufferStatusError ) {
4750-                     GGML_LOG_INFO (" error: %s \n "  , [[command_buffer  error ].localizedDescription UTF8String ]);
4760+                     GGML_LOG_INFO (" error: %s \n "  , [[cmd_buf  error ].localizedDescription UTF8String ]);
47514761                }
47524762
47534763                return  GGML_STATUS_FAILED;
47544764            }
47554765        }
47564766
47574767        for  (int  i = 0 ; i < n_cb; ++i) {
4758-             id <MTLCommandBuffer > command_buffer  = ctx->command_buffers [i];
4759-             [command_buffer  waitUntilCompleted ];
4768+             id <MTLCommandBuffer > cmd_buf  = ctx->cmd_bufs [i]. obj ;
4769+             [cmd_buf  waitUntilCompleted ];
47604770
4761-             //  TODO: per command buffer heap
4762-             ggml_metal_heap_reset (ctx->heap );
4771+             ggml_metal_heap_reset (ctx->cmd_bufs [i].heap );
47634772
4764-             MTLCommandBufferStatus  status = [command_buffer  status ];
4773+             MTLCommandBufferStatus  status = [cmd_buf  status ];
47654774            if  (status != MTLCommandBufferStatusCompleted ) {
47664775                GGML_LOG_INFO (" %s : command buffer %d  failed with status %lu \n "  , __func__, i, status);
47674776                if  (status == MTLCommandBufferStatusError ) {
4768-                     GGML_LOG_INFO (" error: %s \n "  , [[command_buffer  error ].localizedDescription UTF8String ]);
4777+                     GGML_LOG_INFO (" error: %s \n "  , [[cmd_buf  error ].localizedDescription UTF8String ]);
47694778                }
47704779
47714780                return  GGML_STATUS_FAILED;
47724781            }
47734782
4774-             id <MTLCommandBuffer > next_buffer = (i + 1  < n_cb ? ctx->command_buffers [i + 1 ] : nil );
4783+             id <MTLCommandBuffer > next_buffer = (i + 1  < n_cb ? ctx->cmd_bufs [i + 1 ]. obj  : nil );
47754784            if  (!next_buffer) {
47764785                continue ;
47774786            }
@@ -5154,12 +5163,13 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51545163
51555164        const  int  n_nodes_per_cb = ctx->n_nodes_per_cb ;
51565165
5157-         id <MTLCommandBuffer > command_buffer  = ctx->command_buffers [cb_idx];
5166+         id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [cb_idx].obj ;
5167+         struct  ggml_metal_heap * heap = ctx->cmd_bufs [cb_idx].heap ;
51585168
51595169        int  n_try = 3 ;
51605170
51615171        while  (n_try-- > 0 ) {
5162-             id <MTLComputeCommandEncoder > encoder = [command_buffer  computeCommandEncoder ];
5172+             id <MTLComputeCommandEncoder > encoder = [cmd_buf  computeCommandEncoder ];
51635173
51645174            int  node_start = 0 ;
51655175            int  node_end   = n_nodes_0;
@@ -5176,7 +5186,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51765186                    [encoder pushDebugGroup: [NSString  stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
51775187                }
51785188
5179-                 ggml_metal_encode_node (backend, idx, encoder, ctx-> heap );
5189+                 ggml_metal_encode_node (backend, idx, encoder, heap);
51805190
51815191                if  (should_capture) {
51825192                    [encoder popDebugGroup ];
@@ -5185,22 +5195,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51855195
51865196            [encoder endEncoding ];
51875197
5188-             if  (ctx-> heap ->fail  == 0 ) {
5198+             if  (heap->fail  == 0 ) {
51895199                break ;
51905200            }
51915201
5192-             const  size_t  need = ctx-> heap ->need ;
5202+             const  size_t  need = heap->need ;
51935203
51945204            GGML_LOG_INFO (" %s : increasing heap size to %zu \n "  , __func__, need);
51955205
5196-             if  (!ggml_metal_heap_resize (ctx-> heap , need)) {
5206+             if  (!ggml_metal_heap_resize (heap, need)) {
51975207                GGML_LOG_ERROR (" %s : failed to increase heap size to %zu \n "  , __func__, need);
51985208                break ;
51995209            }
52005210        }
52015211
52025212        if  (cb_idx < 2  || ctx->abort_callback  == NULL ) {
5203-             [command_buffer  commit ];
5213+             [cmd_buf  commit ];
52045214        }
52055215    });
52065216}
0 commit comments