Skip to content

Commit aa37645

Browse files
committed
cont : resize heap [no ci]
1 parent 9155e94 commit aa37645

File tree

1 file changed

+118
-46
lines changed

1 file changed

+118
-46
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 118 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -471,18 +471,67 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
471471
GGML_METAL_KERNEL_TYPE_COUNT
472472
};
473473

474-
struct ggml_backend_metal_heap {
474+
// TODO: use MTLHeapTypePlacement and reset offset after every node
475+
struct ggml_metal_heap {
475476
int n;
477+
int fail;
478+
479+
size_t need;
476480

477481
id<MTLHeap> obj;
478482
id<MTLBuffer> bufs[GGML_METAL_MAX_HEAP_BUFFERS];
479483
};
480484

485+
// Clear the heap bookkeeping and release every buffer previously
// sub-allocated from it via ggml_metal_heap_alloc.
// NOTE(review): explicit -release implies this file is compiled without ARC — confirm build flags.
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
    heap->n    = 0;
    heap->fail = 0;
    heap->need = 0;

    // bufs[] is filled sequentially, so the first nil entry marks the end
    for (int i = 0; i < GGML_METAL_MAX_HEAP_BUFFERS; i++) {
        if (heap->bufs[i] == nil) {
            break;
        }

        [heap->bufs[i] release];
        heap->bufs[i] = nil;
    }
}
500+
501+
// Sub-allocate a private MTLBuffer of at least `size` bytes from the heap,
// padded to `alignment`. On failure, returns nil and latches an error code in
// heap->fail; heap->need keeps accumulating even after a failure so the caller
// can grow the heap to the total required size and retry the encode pass.
//
// fail codes:
//   1 - not enough free space left in the heap for this allocation
//   2 - the bufs[] bookkeeping array is full
//   3 - the Metal runtime rejected the allocation (e.g. fragmentation)
static id<MTLBuffer> ggml_metal_heap_alloc(struct ggml_metal_heap * heap, size_t size, size_t alignment) {
    const size_t size_aligned = GGML_PAD(size, alignment);

    // total size required by the current pass - used to size the heap on retry
    heap->need += size_aligned;

    // compare this allocation against the space still available in the heap;
    // the previous check (heap->need > maxAvailable) double-counted buffers
    // already sub-allocated from the heap and could fail spuriously
    if (!heap->fail && size_aligned > [heap->obj maxAvailableSizeWithAlignment:alignment]) {
        heap->fail = 1;
    }

    if (!heap->fail && heap->n >= GGML_METAL_MAX_HEAP_BUFFERS) {
        heap->fail = 2;
    }

    if (heap->fail) {
        return nil;
    }

    id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
    if (!buf) {
        heap->fail = 3;
        return nil;
    }

    // keep the reference so ggml_metal_heap_reset can release it later
    heap->bufs[heap->n++] = buf;

    return buf;
}
528+
481529
struct ggml_backend_metal_context {
530+
id<MTLDevice> device;
482531
id<MTLCommandQueue> queue;
483532

484533
// TODO: create heap per command buffer
485-
struct ggml_backend_metal_heap heap;
534+
struct ggml_metal_heap heap;
486535

487536
dispatch_queue_t d_queue;
488537

@@ -696,9 +745,11 @@ @implementation GGMLMetalClass
696745
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
697746

698747
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
748+
699749
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
700750

701-
ctx->queue = [device newCommandQueue];
751+
ctx->device = device;
752+
ctx->queue = [device newCommandQueue];
702753
if (ctx->queue == nil) {
703754
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
704755
return NULL;
@@ -707,21 +758,22 @@ @implementation GGMLMetalClass
707758
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
708759

709760
// allocate tmp heap with fixed size for testing
710-
// TODO: figure out how to dynamically resize it
761+
// TODO: factor into a function
711762
{
712-
MTLHeapDescriptor *heapDescriptor = [[MTLHeapDescriptor alloc] init];
713-
heapDescriptor.storageMode = MTLStorageModePrivate;
714-
heapDescriptor.cpuCacheMode = MTLCPUCacheModeDefaultCache;
715-
heapDescriptor.size = 32*1024*1024;
763+
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
764+
desc.storageMode = MTLStorageModePrivate;
765+
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
766+
desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement
767+
desc.size = 1024*1024;
716768

717769
ctx->heap.n = 0;
718770

719-
ctx->heap.obj = [device newHeapWithDescriptor:heapDescriptor];
771+
ctx->heap.obj = [device newHeapWithDescriptor:desc];
720772
for (int i = 0; i < GGML_METAL_MAX_HEAP_BUFFERS; ++i) {
721773
ctx->heap.bufs[i] = nil;
722774
}
723775

724-
[heapDescriptor release];
776+
[desc release];
725777
}
726778

727779
// load library
@@ -1471,7 +1523,7 @@ static void ggml_metal_encode_node(
14711523
ggml_backend_t backend,
14721524
int idx,
14731525
id<MTLComputeCommandEncoder> encoder,
1474-
struct ggml_backend_metal_heap * heap) {
1526+
struct ggml_metal_heap * heap) {
14751527
struct ggml_backend_metal_context * ctx = backend->context;
14761528
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
14771529

@@ -2164,11 +2216,16 @@ static void ggml_metal_encode_node(
21642216
/*.nb3 =*/ nb03,
21652217
};
21662218

2167-
id<MTLBuffer> id_src0h = [heap->obj newBufferWithLength:ggml_nbytes(src0) options:MTLResourceStorageModePrivate];
2219+
//id<MTLBuffer> id_src0h = [heap->obj newBufferWithLength:ggml_nbytes(src0) options:MTLResourceStorageModePrivate];
21682220

2169-
// save a reference to the heap-allocated buffer
2170-
// TODO: simplify and check for available resources
2171-
heap->bufs[heap->n++] = id_src0h;
2221+
//// save a reference to the heap-allocated buffer
2222+
//// TODO: simplify and check for available resources
2223+
//heap->bufs[heap->n++] = id_src0h;
2224+
2225+
id<MTLBuffer> id_src0h = ggml_metal_heap_alloc(heap, ggml_nbytes(src0), 32);
2226+
if (!id_src0h) {
2227+
break;
2228+
}
21722229

21732230
if (src0->type == GGML_TYPE_F16) {
21742231
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
@@ -4658,21 +4715,8 @@ static enum ggml_status ggml_metal_graph_compute(
46584715
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
46594716
[command_buffer waitUntilCompleted];
46604717

4661-
// free buffers from the heap
4662-
{
4663-
size_t size_allocated = [ctx->heap.obj currentAllocatedSize];
4664-
size_t size_used = [ctx->heap.obj usedSize];
4665-
GGML_LOG_INFO("%s: command buffer %d, allocated = %zu, used = %zu, n = %d\n", __func__, i, size_allocated, size_used, ctx->heap.n);
4666-
4667-
for (int j = 0; j < ctx->heap.n; ++j) {
4668-
id<MTLBuffer> buf = ctx->heap.bufs[j];
4669-
[buf release];
4670-
4671-
ctx->heap.bufs[j] = nil;
4672-
}
4673-
4674-
ctx->heap.n = 0;
4675-
}
4718+
// TODO: per command buffer heap
4719+
ggml_metal_heap_reset(&ctx->heap);
46764720

46774721
MTLCommandBufferStatus status = [command_buffer status];
46784722
if (status != MTLCommandBufferStatusCompleted) {
@@ -5068,31 +5112,59 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
50685112
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
50695113

50705114
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
5071-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
50725115

5073-
int node_start = 0;
5074-
int node_end = n_nodes_0;
5116+
int n_try = 3;
50755117

5076-
if (cb_idx < n_cb_l) {
5077-
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
5078-
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
5079-
}
5118+
while (n_try-- > 0) {
5119+
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
50805120

5081-
const bool should_capture = ctx->capture_next_compute;
5121+
int node_start = 0;
5122+
int node_end = n_nodes_0;
50825123

5083-
for (int idx = node_start; idx < node_end; ++idx) {
5084-
if (should_capture) {
5085-
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
5124+
if (cb_idx < n_cb_l) {
5125+
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
5126+
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
50865127
}
50875128

5088-
ggml_metal_encode_node(backend, idx, encoder, &ctx->heap);
5129+
const bool should_capture = ctx->capture_next_compute;
5130+
5131+
for (int idx = node_start; idx < node_end; ++idx) {
5132+
if (should_capture) {
5133+
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
5134+
}
5135+
5136+
ggml_metal_encode_node(backend, idx, encoder, &ctx->heap);
50895137

5090-
if (should_capture) {
5091-
[encoder popDebugGroup];
5138+
if (should_capture) {
5139+
[encoder popDebugGroup];
5140+
}
50925141
}
5093-
}
50945142

5095-
[encoder endEncoding];
5143+
[encoder endEncoding];
5144+
5145+
if (ctx->heap.fail == 0) {
5146+
break;
5147+
}
5148+
5149+
// increase heap size
5150+
[ctx->heap.obj release];
5151+
5152+
{
5153+
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
5154+
desc.storageMode = MTLStorageModePrivate;
5155+
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
5156+
desc.type = MTLHeapTypeAutomatic; // TODO: use MTLHeapTypePlacement
5157+
desc.size = ctx->heap.need;
5158+
5159+
GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, ctx->heap.need);
5160+
5161+
ctx->heap.obj = [ctx->device newHeapWithDescriptor:desc];
5162+
5163+
[desc release];
5164+
}
5165+
5166+
ggml_metal_heap_reset(&ctx->heap);
5167+
}
50965168

50975169
if (cb_idx < 2 || ctx->abort_callback == NULL) {
50985170
[command_buffer commit];

0 commit comments

Comments
 (0)