@@ -471,18 +471,67 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
471471 GGML_METAL_KERNEL_TYPE_COUNT
472472};
473473
// Bookkeeping for a Metal heap together with the buffers that were allocated from it.
// TODO: use MTLHeapTypePlacement and reset offset after every node
struct ggml_metal_heap {
    int n;    // number of slots of bufs currently in use
    int fail; // 0 while allocations succeed, non-zero once ggml_metal_heap_alloc has failed

    size_t need; // running total of the aligned sizes requested through ggml_metal_heap_alloc

    id<MTLHeap>   obj;                               // the heap the buffers are allocated from
    id<MTLBuffer> bufs[GGML_METAL_MAX_HEAP_BUFFERS]; // buffers handed out by ggml_metal_heap_alloc
};
480484
// Put the heap bookkeeping back into its initial state: zero the counters and release the
// buffers stored in heap->bufs. The heap object itself (heap->obj) is not touched.
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
    heap->need = 0;
    heap->fail = 0;
    heap->n    = 0;

    // walk the slots from the front and stop at the first empty one
    int i = 0;
    while (i < GGML_METAL_MAX_HEAP_BUFFERS && heap->bufs[i]) {
        [heap->bufs[i] release];
        heap->bufs[i] = nil;
        i++;
    }
}
500+
// Allocate a private-storage buffer of at least `size` bytes from the heap.
//
// The size is rounded up with GGML_PAD(size, alignment) and is always added to heap->need,
// even after a failure, so heap->need accumulates the total number of bytes requested since
// it was last cleared.
//
// Returns the new buffer (also stored in heap->bufs), or nil on failure. On failure
// heap->fail records the first reason that was hit and every later call returns nil:
//   1 - the request does not fit into the space that is still available in the heap
//   2 - all GGML_METAL_MAX_HEAP_BUFFERS slots are already in use
//   3 - the heap did not return a buffer
static id<MTLBuffer> ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t size, size_t alignment) {
    const size_t size_aligned = GGML_PAD(size, alignment);

    // keep counting after a failure so that heap->need reflects the full demand
    heap->need += size_aligned;

    if (heap->fail) {
        return nil;
    }

    // maxAvailableSizeWithAlignment reports the space that is still free, so only the current
    // request is compared against it - using the cumulative heap->need here would count the
    // buffers that were already allocated twice and reject requests that actually fit
    if (size_aligned > [heap->obj maxAvailableSizeWithAlignment:alignment]) {
        heap->fail = 1;
        return nil;
    }

    if (heap->n >= GGML_METAL_MAX_HEAP_BUFFERS) {
        heap->fail = 2;
        return nil;
    }

    id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
    if (!buf) {
        heap->fail = 3;
        return nil;
    }

    // keep a reference so that the buffer can be released when the heap is reset
    heap->bufs[heap->n++] = buf;

    return buf;
}
528+
481529struct ggml_backend_metal_context {
530+ id <MTLDevice > device;
482531 id <MTLCommandQueue > queue;
483532
484533 // TODO: create heap per command buffer
485- struct ggml_backend_metal_heap heap;
534+ struct ggml_metal_heap heap;
486535
487536 dispatch_queue_t d_queue;
488537
@@ -696,9 +745,11 @@ @implementation GGMLMetalClass
696745 struct ggml_backend_metal_device_context * ctx_dev = dev->context ;
697746
698747 id <MTLDevice > device = ggml_backend_metal_device_acq (ctx_dev);
748+
699749 GGML_LOG_INFO (" %s : picking default device: %s \n " , __func__, [[device name ] UTF8String ]);
700750
701- ctx->queue = [device newCommandQueue ];
751+ ctx->device = device;
752+ ctx->queue = [device newCommandQueue ];
702753 if (ctx->queue == nil ) {
703754 GGML_LOG_ERROR (" %s : error: failed to create command queue\n " , __func__);
704755 return NULL ;
@@ -707,21 +758,22 @@ @implementation GGMLMetalClass
707758 ctx->d_queue = dispatch_queue_create (" ggml-metal" , DISPATCH_QUEUE_CONCURRENT);
708759
709760 // allocate tmp heap with fixed size for testing
710- // TODO: figure out how to dynamically resize it
761+ // TODO: factor into a function
711762 {
712- MTLHeapDescriptor *heapDescriptor = [[MTLHeapDescriptor alloc ] init ];
713- heapDescriptor.storageMode = MTLStorageModePrivate ;
714- heapDescriptor.cpuCacheMode = MTLCPUCacheModeDefaultCache ;
715- heapDescriptor.size = 32 *1024 *1024 ;
763+ MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc ] init ];
764+ desc.storageMode = MTLStorageModePrivate ;
765+ desc.cpuCacheMode = MTLCPUCacheModeDefaultCache ;
766+ desc.type = MTLHeapTypeAutomatic ; // TODO: use MTLHeapTypePlacement
767+ desc.size = 1024 *1024 ;
716768
717769 ctx->heap .n = 0 ;
718770
719- ctx->heap .obj = [device newHeapWithDescriptor: heapDescriptor ];
771+ ctx->heap .obj = [device newHeapWithDescriptor: desc ];
720772 for (int i = 0 ; i < GGML_METAL_MAX_HEAP_BUFFERS; ++i) {
721773 ctx->heap .bufs [i] = nil ;
722774 }
723775
724- [heapDescriptor release ];
776+ [desc release ];
725777 }
726778
727779 // load library
@@ -1471,7 +1523,7 @@ static void ggml_metal_encode_node(
14711523 ggml_backend_t backend,
14721524 int idx,
14731525 id <MTLComputeCommandEncoder > encoder,
1474- struct ggml_backend_metal_heap * heap) {
1526+ struct ggml_metal_heap * heap) {
14751527 struct ggml_backend_metal_context * ctx = backend->context ;
14761528 struct ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
14771529
@@ -2164,11 +2216,16 @@ static void ggml_metal_encode_node(
21642216 /* .nb3 =*/ nb03,
21652217 };
21662218
2167- id <MTLBuffer > id_src0h = [heap->obj newBufferWithLength: ggml_nbytes (src0) options: MTLResourceStorageModePrivate ];
2219+ // id<MTLBuffer> id_src0h = [heap->obj newBufferWithLength:ggml_nbytes(src0) options:MTLResourceStorageModePrivate];
21682220
2169- // save a reference to the heap-allocated buffer
2170- // TODO: simplify and check for available resources
2171- heap->bufs [heap->n++] = id_src0h;
2221+ // // save a reference to the heap-allocated buffer
2222+ // // TODO: simplify and check for available resources
2223+ // heap->bufs[heap->n++] = id_src0h;
2224+
2225+ id <MTLBuffer > id_src0h = ggml_metal_heap_alloc (heap, ggml_nbytes (src0), 32 );
2226+ if (!id_src0h) {
2227+ break ;
2228+ }
21722229
21732230 if (src0->type == GGML_TYPE_F16) {
21742231 [encoder setComputePipelineState: ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
@@ -4658,21 +4715,8 @@ static enum ggml_status ggml_metal_graph_compute(
46584715 id <MTLCommandBuffer > command_buffer = ctx->command_buffers [i];
46594716 [command_buffer waitUntilCompleted ];
46604717
4661- // free buffers from the heap
4662- {
4663- size_t size_allocated = [ctx->heap.obj currentAllocatedSize ];
4664- size_t size_used = [ctx->heap.obj usedSize ];
4665- GGML_LOG_INFO (" %s : command buffer %d , allocated = %zu , used = %zu , n = %d \n " , __func__, i, size_allocated, size_used, ctx->heap .n );
4666-
4667- for (int j = 0 ; j < ctx->heap .n ; ++j) {
4668- id <MTLBuffer > buf = ctx->heap .bufs [j];
4669- [buf release ];
4670-
4671- ctx->heap .bufs [j] = nil ;
4672- }
4673-
4674- ctx->heap .n = 0 ;
4675- }
4718+ // TODO: per command buffer heap
4719+ ggml_metal_heap_reset (&ctx->heap );
46764720
46774721 MTLCommandBufferStatus status = [command_buffer status ];
46784722 if (status != MTLCommandBufferStatusCompleted ) {
@@ -5068,31 +5112,59 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
50685112 const int n_nodes_per_cb = ctx->n_nodes_per_cb ;
50695113
50705114 id <MTLCommandBuffer > command_buffer = ctx->command_buffers [cb_idx];
5071- id <MTLComputeCommandEncoder > encoder = [command_buffer computeCommandEncoder ];
50725115
5073- int node_start = 0 ;
5074- int node_end = n_nodes_0;
5116+ int n_try = 3 ;
50755117
5076- if (cb_idx < n_cb_l) {
5077- node_start = n_nodes_0 + ( (cb_idx + 0 ) * n_nodes_per_cb);
5078- node_end = n_nodes_0 + (MIN ((cb_idx == n_cb_l - 1 ) ? n_nodes_1 : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes_1));
5079- }
5118+ while (n_try-- > 0 ) {
5119+ id <MTLComputeCommandEncoder > encoder = [command_buffer computeCommandEncoder ];
50805120
5081- const bool should_capture = ctx->capture_next_compute ;
5121+ int node_start = 0 ;
5122+ int node_end = n_nodes_0;
50825123
5083- for ( int idx = node_start; idx < node_end; ++idx ) {
5084- if (should_capture) {
5085- [encoder pushDebugGroup: [ NSString stringWithCString: ggml_op_desc ( ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]] ;
5124+ if (cb_idx < n_cb_l ) {
5125+ node_start = n_nodes_0 + ( (cb_idx + 0 ) * n_nodes_per_cb);
5126+ node_end = n_nodes_0 + ( MIN ((cb_idx == n_cb_l - 1 ) ? n_nodes_1 : (cb_idx + 1 ) * n_nodes_per_cb, n_nodes_1)) ;
50865127 }
50875128
5088- ggml_metal_encode_node (backend, idx, encoder, &ctx->heap );
5129+ const bool should_capture = ctx->capture_next_compute ;
5130+
5131+ for (int idx = node_start; idx < node_end; ++idx) {
5132+ if (should_capture) {
5133+ [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
5134+ }
5135+
5136+ ggml_metal_encode_node (backend, idx, encoder, &ctx->heap );
50895137
5090- if (should_capture) {
5091- [encoder popDebugGroup ];
5138+ if (should_capture) {
5139+ [encoder popDebugGroup ];
5140+ }
50925141 }
5093- }
50945142
5095- [encoder endEncoding ];
5143+ [encoder endEncoding ];
5144+
5145+ if (ctx->heap .fail == 0 ) {
5146+ break ;
5147+ }
5148+
5149+ // increase heap size
5150+ [ctx->heap.obj release ];
5151+
5152+ {
5153+ MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc ] init ];
5154+ desc.storageMode = MTLStorageModePrivate ;
5155+ desc.cpuCacheMode = MTLCPUCacheModeDefaultCache ;
5156+ desc.type = MTLHeapTypeAutomatic ; // TODO: use MTLHeapTypePlacement
5157+ desc.size = ctx->heap .need ;
5158+
5159+ GGML_LOG_INFO (" %s : increasing heap size to %zu \n " , __func__, ctx->heap .need );
5160+
5161+ ctx->heap .obj = [ctx->device newHeapWithDescriptor: desc];
5162+
5163+ [desc release ];
5164+ }
5165+
5166+ ggml_metal_heap_reset (&ctx->heap );
5167+ }
50965168
50975169 if (cb_idx < 2 || ctx->abort_callback == NULL ) {
50985170 [command_buffer commit ];
0 commit comments