@@ -591,13 +591,16 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
591591 return buf;
592592}
593593
594+ struct ggml_metal_command_buffer { // pairs one Metal command buffer with the heap used while encoding into it
595+ id <MTLCommandBuffer > obj; // the command buffer itself; assigned from [ctx->queue commandBufferWithUnretainedReferences]
596+
597+ struct ggml_metal_heap * heap; // per-command-buffer heap: passed to ggml_metal_encode_node, resized when heap->fail is set, reset after waitUntilCompleted
598+ };
599+
594600struct ggml_backend_metal_context {
595601 id <MTLDevice > device;
596602 id <MTLCommandQueue > queue;
597603
598- // TODO: create heap per command buffer
599- struct ggml_metal_heap * heap;
600-
601604 dispatch_queue_t d_queue;
602605
603606 struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
@@ -620,7 +623,8 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
620623 void (^encode_async)(size_t ith);
621624
622625 // n_cb command buffers + 1 used by the main thread
623- id <MTLCommandBuffer > command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1 ];
626+ // id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
627+ struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1 ];
624628
625629 // abort ggml_metal_graph_compute if callback returns true
626630 ggml_abort_callback abort_callback;
@@ -822,8 +826,6 @@ @implementation GGMLMetalClass
822826
823827 ctx->d_queue = dispatch_queue_create (" ggml-metal" , DISPATCH_QUEUE_CONCURRENT);
824828
825- ctx->heap = ggml_metal_heap_init (device, 1024 *1024 );
826-
827829 // load library
828830 if (ctx_dev->mtl_library == nil ) {
829831 ctx_dev->mtl_library = ggml_metal_load_library (device, ctx_dev->use_bfloat );
@@ -877,7 +879,11 @@ @implementation GGMLMetalClass
877879 ctx->gf = nil ;
878880 ctx->encode_async = nil ;
879881 for (int i = 0 ; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
880- ctx->command_buffers [i] = nil ;
882+ ctx->cmd_bufs [i].obj = nil ;
883+
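884+ // create a 1MB heap for each command buffer; the heaps are resized during compute when necessary
885+ // NOTE(review): cmd_bufs has GGML_METAL_MAX_COMMAND_BUFFERS + 1 entries but this loop stops at GGML_METAL_MAX_COMMAND_BUFFERS - confirm the main-thread slot cmd_bufs[n_cb] always gets a heap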
886+ ctx->cmd_bufs [i].heap = ggml_metal_heap_init (device, 1024 *1024 );
881887 }
882888
883889#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1268,7 +1274,11 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
12681274
12691275 [ctx->queue release ];
12701276
1271- ggml_metal_heap_free (ctx->heap );
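1277+ // free the per-command-buffer heaps (replaces the single ctx->heap); NOTE(review): .obj comes from commandBufferWithUnretainedReferences with no visible retain - confirm the release below is balanced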
1278+ for (int i = 0 ; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
1279+ [ctx->cmd_bufs[i].obj release ];
1280+ ggml_metal_heap_free (ctx->cmd_bufs [i].heap );
1281+ }
12721282
12731283 dispatch_release (ctx->d_queue );
12741284
@@ -4712,25 +4722,25 @@ static enum ggml_status ggml_metal_graph_compute(
47124722 }
47134723
47144724 // the main thread commits the first few commands immediately
4715- // command_buffer [n_cb]
4725+ // cmd_buf [n_cb]
47164726 {
4717- id <MTLCommandBuffer > command_buffer = [ctx->queue commandBufferWithUnretainedReferences ];
4718- ctx->command_buffers [n_cb] = command_buffer ;
4727+ id <MTLCommandBuffer > cmd_buf = [ctx->queue commandBufferWithUnretainedReferences ];
4728+ ctx->cmd_bufs [n_cb]. obj = cmd_buf ;
47194729
4720- [command_buffer enqueue ];
4730+ [cmd_buf enqueue ];
47214731 ctx->encode_async (n_cb);
47224732 }
47234733
47244734 // prepare the rest of the command buffers asynchronously
4725- // command_buffer [0.. n_cb)
4735+ // cmd_buf [0.. n_cb)
47264736 for (int cb_idx = 0 ; cb_idx < n_cb; ++cb_idx) {
4727- id <MTLCommandBuffer > command_buffer = [ctx->queue commandBufferWithUnretainedReferences ];
4728- ctx->command_buffers [cb_idx] = command_buffer ;
4737+ id <MTLCommandBuffer > cmd_buf = [ctx->queue commandBufferWithUnretainedReferences ];
4738+ ctx->cmd_bufs [cb_idx]. obj = cmd_buf ;
47294739
47304740 // always enqueue the first two command buffers
47314741 // enqueue all of the command buffers if we don't need to abort
47324742 if (cb_idx < 2 || ctx->abort_callback == NULL ) {
4733- [command_buffer enqueue ];
4743+ [cmd_buf enqueue ];
47344744 }
47354745 }
47364746
@@ -4739,40 +4749,39 @@ static enum ggml_status ggml_metal_graph_compute(
47394749 // wait for completion and check status of each command buffer
47404750 // needed to detect if the device ran out-of-memory for example (#1881)
47414751 {
4742- id <MTLCommandBuffer > command_buffer = ctx->command_buffers [n_cb];
4743- [command_buffer waitUntilCompleted ];
4752+ id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [n_cb]. obj ;
4753+ [cmd_buf waitUntilCompleted ];
47444754
4745- // TODO: free main cb heap
4755+ ggml_metal_heap_reset (ctx-> cmd_bufs [n_cb]. heap );
47464756
4747- MTLCommandBufferStatus status = [command_buffer status ];
4757+ MTLCommandBufferStatus status = [cmd_buf status ];
47484758 if (status != MTLCommandBufferStatusCompleted ) {
47494759 GGML_LOG_INFO (" %s : command buffer %d failed with status %lu \n " , __func__, n_cb, status);
47504760 if (status == MTLCommandBufferStatusError ) {
4751- GGML_LOG_INFO (" error: %s \n " , [[command_buffer error ].localizedDescription UTF8String ]);
4761+ GGML_LOG_INFO (" error: %s \n " , [[cmd_buf error ].localizedDescription UTF8String ]);
47524762 }
47534763
47544764 return GGML_STATUS_FAILED;
47554765 }
47564766 }
47574767
47584768 for (int i = 0 ; i < n_cb; ++i) {
4759- id <MTLCommandBuffer > command_buffer = ctx->command_buffers [i];
4760- [command_buffer waitUntilCompleted ];
4769+ id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [i]. obj ;
4770+ [cmd_buf waitUntilCompleted ];
47614771
4762- // TODO: per command buffer heap
4763- ggml_metal_heap_reset (ctx->heap );
4772+ ggml_metal_heap_reset (ctx->cmd_bufs [i].heap );
47644773
4765- MTLCommandBufferStatus status = [command_buffer status ];
4774+ MTLCommandBufferStatus status = [cmd_buf status ];
47664775 if (status != MTLCommandBufferStatusCompleted ) {
47674776 GGML_LOG_INFO (" %s : command buffer %d failed with status %lu \n " , __func__, i, status);
47684777 if (status == MTLCommandBufferStatusError ) {
4769- GGML_LOG_INFO (" error: %s \n " , [[command_buffer error ].localizedDescription UTF8String ]);
4778+ GGML_LOG_INFO (" error: %s \n " , [[cmd_buf error ].localizedDescription UTF8String ]);
47704779 }
47714780
47724781 return GGML_STATUS_FAILED;
47734782 }
47744783
4775- id <MTLCommandBuffer > next_buffer = (i + 1 < n_cb ? ctx->command_buffers [i + 1 ] : nil );
4784+ id <MTLCommandBuffer > next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs [i + 1 ]. obj : nil );
47764785 if (!next_buffer) {
47774786 continue ;
47784787 }
@@ -5155,12 +5164,13 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51555164
51565165 const int n_nodes_per_cb = ctx->n_nodes_per_cb ;
51575166
5158- id <MTLCommandBuffer > command_buffer = ctx->command_buffers [cb_idx];
5167+ id <MTLCommandBuffer > cmd_buf = ctx->cmd_bufs [cb_idx].obj ;
5168+ struct ggml_metal_heap * heap = ctx->cmd_bufs [cb_idx].heap ;
51595169
51605170 int n_try = 3 ;
51615171
51625172 while (n_try-- > 0 ) {
5163- id <MTLComputeCommandEncoder > encoder = [command_buffer computeCommandEncoder ];
5173+ id <MTLComputeCommandEncoder > encoder = [cmd_buf computeCommandEncoder ];
51645174
51655175 int node_start = 0 ;
51665176 int node_end = n_nodes_0;
@@ -5177,7 +5187,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51775187 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
51785188 }
51795189
5180- ggml_metal_encode_node (backend, idx, encoder, ctx-> heap );
5190+ ggml_metal_encode_node (backend, idx, encoder, heap);
51815191
51825192 if (should_capture) {
51835193 [encoder popDebugGroup ];
@@ -5186,22 +5196,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51865196
51875197 [encoder endEncoding ];
51885198
5189- if (ctx-> heap ->fail == 0 ) {
5199+ if (heap->fail == 0 ) {
51905200 break ;
51915201 }
51925202
5193- const size_t need = ctx-> heap ->need ;
5203+ const size_t need = heap->need ;
51945204
51955205 GGML_LOG_INFO (" %s : increasing heap size to %zu \n " , __func__, need);
51965206
5197- if (!ggml_metal_heap_resize (ctx-> heap , need)) {
5207+ if (!ggml_metal_heap_resize (heap, need)) {
51985208 GGML_LOG_ERROR (" %s : failed to increase heap size to %zu \n " , __func__, need);
51995209 break ;
52005210 }
52015211 }
52025212
52035213 if (cb_idx < 2 || ctx->abort_callback == NULL ) {
5204- [command_buffer commit ];
5214+ [cmd_buf commit ];
52055215 }
52065216 });
52075217}
0 commit comments