@@ -471,17 +471,55 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
471471 GGML_METAL_KERNEL_TYPE_COUNT
472472};
473473
474- // TODO: use MTLHeapTypePlacement and reset offset after every node
// Bookkeeping for a single MTLHeap and the buffers carved out of it.
struct ggml_metal_heap {
    // presumably the number of entries of bufs[] currently in use; cleared by ggml_metal_heap_reset()
    int n;

    // cleared by ggml_metal_heap_reset(); the encode path treats a non-zero value as
    // "the heap was too small" and resizes the heap to `need`
    int fail;

    // heap size (in bytes) requested on the next resize after a failure
    size_t need;

    // device the MTLHeap is created on (stored as-is, not retained here)
    id<MTLDevice> device;

    // the underlying Metal heap
    id<MTLHeap> obj;

    // references to buffers associated with this heap (presumably those allocated from `obj` - TODO confirm)
    id<MTLBuffer> bufs[GGML_METAL_MAX_HEAP_BUFFERS];
};
484484
// Create a ggml_metal_heap backed by a private-storage MTLHeap of `size` bytes on `device`.
//
// Returns the newly allocated heap, or NULL when either the host-side allocation or the
// MTLHeap creation fails. A non-NULL result must be destroyed with ggml_metal_heap_free().
// The device pointer is stored as-is (it is not retained by this function).
static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
    struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
    if (heap == NULL) {
        GGML_LOG_ERROR("%s: error: failed to allocate heap context\n", __func__);

        return NULL;
    }

    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
    desc.storageMode  = MTLStorageModePrivate;
    desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
    desc.type         = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement
    desc.size         = size;

    heap->device = device;
    heap->obj    = [device newHeapWithDescriptor:desc];

    // the descriptor is only needed to create the heap - release it before the error check
    // so that it is not leaked on the failure path
    [desc release];

    if (!heap->obj) {
        GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);

        free(heap);

        return NULL;
    }

    for (int i = 0; i < GGML_METAL_MAX_HEAP_BUFFERS; ++i) {
        heap->bufs[i] = nil;
    }

    return heap;
}
512+
// Destroy a heap object: drop our reference to the MTLHeap and free the host-side struct.
// Passing NULL is a no-op.
// NOTE: the entries of heap->bufs[] are not released here.
static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
    if (heap != NULL) {
        [heap->obj release];

        free(heap);
    }
}
522+
485523static void ggml_metal_heap_reset (struct ggml_metal_heap * heap) {
486524 heap->n = 0 ;
487525 heap->fail = 0 ;
@@ -498,6 +536,33 @@ static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
498536 }
499537}
500538
// Replace the MTLHeap owned by `heap` with a new private-storage heap of `size` bytes,
// created on the device stored in the heap, and reset the heap bookkeeping.
//
// Returns true on success. Returns false if `heap` is NULL or if the new MTLHeap cannot be
// created; in the latter case the previous MTLHeap has already been released, heap->obj is
// nil, and the bookkeeping is left untouched (ggml_metal_heap_reset() is not called).
static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
    if (heap == NULL) {
        return false;
    }

    // drop the old heap first so that both heaps do not have to be owned at the same time
    [heap->obj release];

    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
    desc.storageMode  = MTLStorageModePrivate;
    desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
    desc.type         = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement
    desc.size         = size;

    heap->obj = [heap->device newHeapWithDescriptor:desc];

    // release the descriptor before the error check so that it is not leaked on failure
    [desc release];

    if (!heap->obj) {
        GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);

        return false;
    }

    ggml_metal_heap_reset(heap);

    return true;
}
565+
501566static id <MTLBuffer > ggml_metal_heap_alloc (struct ggml_metal_heap * heap, size_t size, size_t alignment) {
502567 const size_t size_aligned = GGML_PAD (size, alignment);
503568
@@ -531,7 +596,7 @@ static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
531596 id <MTLCommandQueue > queue;
532597
533598 // TODO: create heap per command buffer
534- struct ggml_metal_heap heap;
599+ struct ggml_metal_heap * heap;
535600
536601 dispatch_queue_t d_queue;
537602
@@ -757,24 +822,7 @@ @implementation GGMLMetalClass
757822
758823 ctx->d_queue = dispatch_queue_create (" ggml-metal" , DISPATCH_QUEUE_CONCURRENT);
759824
760- // allocate tmp heap with fixed size for testing
761- // TODO: factor into a function
762- {
763- MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc ] init ];
764- desc.storageMode = MTLStorageModePrivate ;
765- desc.cpuCacheMode = MTLCPUCacheModeDefaultCache ;
766- desc.type = MTLHeapTypeAutomatic ; // TODO: use MTLHeapTypePlacement
767- desc.size = 1024 *1024 ;
768-
769- ctx->heap .n = 0 ;
770-
771- ctx->heap .obj = [device newHeapWithDescriptor: desc];
772- for (int i = 0 ; i < GGML_METAL_MAX_HEAP_BUFFERS; ++i) {
773- ctx->heap .bufs [i] = nil ;
774- }
775-
776- [desc release ];
777- }
825+ ctx->heap = ggml_metal_heap_init (device, 1024 *1024 );
778826
779827 // load library
780828 if (ctx_dev->mtl_library == nil ) {
@@ -1218,8 +1266,9 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
12181266
12191267 Block_release (ctx->encode_async );
12201268
1221- [ctx->queue release ];
1222- [ctx->heap.obj release ];
1269+ [ctx->queue release ];
1270+
1271+ ggml_metal_heap_free (ctx->heap );
12231272
12241273 dispatch_release (ctx->d_queue );
12251274
@@ -2216,12 +2265,6 @@ static void ggml_metal_encode_node(
22162265 /* .nb3 =*/ nb03,
22172266 };
22182267
2219- // id<MTLBuffer> id_src0h = [heap->obj newBufferWithLength:ggml_nbytes(src0) options:MTLResourceStorageModePrivate];
2220-
2221- // // save a reference to the heap-allocated buffer
2222- // // TODO: simplify and check for available resources
2223- // heap->bufs[heap->n++] = id_src0h;
2224-
22252268 id <MTLBuffer > id_src0h = ggml_metal_heap_alloc (heap, ggml_nbytes (src0), 32 );
22262269 if (!id_src0h) {
22272270 break ;
@@ -4716,7 +4759,7 @@ static enum ggml_status ggml_metal_graph_compute(
47164759 [command_buffer waitUntilCompleted ];
47174760
47184761 // TODO: per command buffer heap
4719- ggml_metal_heap_reset (& ctx->heap );
4762+ ggml_metal_heap_reset (ctx->heap );
47204763
47214764 MTLCommandBufferStatus status = [command_buffer status ];
47224765 if (status != MTLCommandBufferStatusCompleted ) {
@@ -5133,7 +5176,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51335176 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
51345177 }
51355178
5136- ggml_metal_encode_node (backend, idx, encoder, & ctx->heap );
5179+ ggml_metal_encode_node (backend, idx, encoder, ctx->heap );
51375180
51385181 if (should_capture) {
51395182 [encoder popDebugGroup ];
@@ -5142,28 +5185,18 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51425185
51435186 [encoder endEncoding ];
51445187
5145- if (ctx->heap . fail == 0 ) {
5188+ if (ctx->heap -> fail == 0 ) {
51465189 break ;
51475190 }
51485191
5149- // increase heap size
5150- [ctx->heap.obj release ];
5192+ const size_t need = ctx->heap ->need ;
51515193
5152- {
5153- MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc ] init ];
5154- desc.storageMode = MTLStorageModePrivate ;
5155- desc.cpuCacheMode = MTLCPUCacheModeDefaultCache ;
5156- desc.type = MTLHeapTypeAutomatic ; // TODO: use MTLHeapTypePlacement
5157- desc.size = ctx->heap .need ;
5158-
5159- GGML_LOG_INFO (" %s : increasing heap size to %zu \n " , __func__, ctx->heap .need );
5160-
5161- ctx->heap .obj = [ctx->device newHeapWithDescriptor: desc];
5194+ GGML_LOG_INFO (" %s : increasing heap size to %zu \n " , __func__, need);
51625195
5163- [desc release ];
5196+ if (!ggml_metal_heap_resize (ctx->heap , need)) {
5197+ GGML_LOG_ERROR (" %s : failed to increase heap size to %zu \n " , __func__, need);
5198+ break ;
51645199 }
5165-
5166- ggml_metal_heap_reset (&ctx->heap );
51675200 }
51685201
51695202 if (cb_idx < 2 || ctx->abort_callback == NULL ) {
0 commit comments