1919// max number of MTLCommandBuffer used to submit a graph for processing
2020#define GGML_METAL_MAX_COMMAND_BUFFERS 8
2121
22+ // max number of buffers that can be allocated on the heap per command buffer
23+ #define GGML_METAL_MAX_HEAP_BUFFERS 64
24+
2225#ifndef TARGET_OS_VISION
2326#define TARGET_OS_VISION 0
2427#endif
@@ -468,9 +471,18 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
468471 GGML_METAL_KERNEL_TYPE_COUNT
469472};
470473
474+ struct ggml_backend_metal_heap {
475+ int n;
476+
477+ id <MTLHeap > obj;
478+ id <MTLBuffer > bufs[GGML_METAL_MAX_HEAP_BUFFERS];
479+ };
480+
471481struct ggml_backend_metal_context {
472482 id <MTLCommandQueue > queue;
473- id <MTLHeap > heap;
483+
484+ // TODO: create heap per command buffer
485+ struct ggml_backend_metal_heap heap;
474486
475487 dispatch_queue_t d_queue;
476488
@@ -702,7 +714,12 @@ @implementation GGMLMetalClass
702714 heapDescriptor.cpuCacheMode = MTLCPUCacheModeDefaultCache ;
703715 heapDescriptor.size = 32 *1024 *1024 ;
704716
705- ctx->heap = [device newHeapWithDescriptor: heapDescriptor];
717+ ctx->heap .n = 0 ;
718+
719+ ctx->heap .obj = [device newHeapWithDescriptor: heapDescriptor];
720+ for (int i = 0 ; i < GGML_METAL_MAX_HEAP_BUFFERS; ++i) {
721+ ctx->heap .bufs [i] = nil ;
722+ }
706723
707724 [heapDescriptor release ];
708725 }
@@ -1149,8 +1166,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
11491166
11501167 Block_release (ctx->encode_async );
11511168
1152- [ctx->queue release ];
1153- [ctx->heap release ];
1169+ [ctx->queue release ];
1170+ [ctx->heap.obj release ];
11541171
11551172 dispatch_release (ctx->d_queue );
11561173
@@ -1454,7 +1471,7 @@ static void ggml_metal_encode_node(
14541471 ggml_backend_t backend,
14551472 int idx,
14561473 id <MTLComputeCommandEncoder > encoder,
1457- id < MTLHeap > heap) {
1474+ struct ggml_backend_metal_heap * heap) {
14581475 struct ggml_backend_metal_context * ctx = backend->context ;
14591476 struct ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
14601477
@@ -2147,7 +2164,11 @@ static void ggml_metal_encode_node(
21472164 /* .nb3 =*/ nb03,
21482165 };
21492166
2150- id <MTLBuffer > id_src0h = [heap newBufferWithLength: ggml_nbytes (src0) options: MTLResourceStorageModePrivate ];
2167+ id <MTLBuffer > id_src0h = [heap->obj newBufferWithLength: ggml_nbytes (src0) options: MTLResourceStorageModePrivate ];
2168+
2169+ // save a reference to the heap-allocated buffer
2170+ // TODO: simplify and check for available resources
2171+ heap->bufs [heap->n++] = id_src0h;
21512172
21522173 if (src0->type == GGML_TYPE_F16) {
21532174 [encoder setComputePipelineState: ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
@@ -4620,6 +4641,8 @@ static enum ggml_status ggml_metal_graph_compute(
46204641 id <MTLCommandBuffer > command_buffer = ctx->command_buffers [n_cb];
46214642 [command_buffer waitUntilCompleted ];
46224643
4644+ // TODO: free main cb heap
4645+
46234646 MTLCommandBufferStatus status = [command_buffer status ];
46244647 if (status != MTLCommandBufferStatusCompleted ) {
46254648 GGML_LOG_INFO (" %s : command buffer %d failed with status %lu \n " , __func__, n_cb, status);
@@ -4635,6 +4658,22 @@ static enum ggml_status ggml_metal_graph_compute(
46354658 id <MTLCommandBuffer > command_buffer = ctx->command_buffers [i];
46364659 [command_buffer waitUntilCompleted ];
46374660
4661+ // free buffers from the heap
4662+ {
4663+ size_t size_allocated = [ctx->heap.obj currentAllocatedSize ];
4664+ size_t size_used = [ctx->heap.obj usedSize ];
4665+ GGML_LOG_INFO (" %s : command buffer %d , allocated = %zu , used = %zu , n = %d \n " , __func__, i, size_allocated, size_used, ctx->heap .n );
4666+
4667+ for (int j = 0 ; j < ctx->heap .n ; ++j) {
4668+ id <MTLBuffer > buf = ctx->heap .bufs [j];
4669+ [buf release ];
4670+
4671+ ctx->heap .bufs [j] = nil ;
4672+ }
4673+
4674+ ctx->heap .n = 0 ;
4675+ }
4676+
46384677 MTLCommandBufferStatus status = [command_buffer status ];
46394678 if (status != MTLCommandBufferStatusCompleted ) {
46404679 GGML_LOG_INFO (" %s : command buffer %d failed with status %lu \n " , __func__, i, status);
@@ -5046,7 +5085,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
50465085 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
50475086 }
50485087
5049- ggml_metal_encode_node (backend, idx, encoder, ctx->heap );
5088+ ggml_metal_encode_node (backend, idx, encoder, & ctx->heap );
50505089
50515090 if (should_capture) {
50525091 [encoder popDebugGroup ];
0 commit comments