Skip to content

Commit 1a5970b

Browse files
committed
cont : heap for each cmd buffer [no ci]
1 parent 404fe19 commit 1a5970b

File tree

1 file changed

+45
-35
lines changed

1 file changed

+45
-35
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -591,13 +591,16 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
591591
return buf;
592592
}
593593

594+
// pairs a Metal command buffer with the heap that is passed to ggml_metal_encode_node
// while the nodes of that command buffer are being encoded
struct ggml_metal_command_buffer {
595+
id<MTLCommandBuffer> obj; // set to nil at context init; assigned from [queue commandBufferWithUnretainedReferences] in ggml_metal_graph_compute
596+
597+
struct ggml_metal_heap * heap; // created with ggml_metal_heap_init at context init; resized when encoding reports heap->fail; reset after waitUntilCompleted
598+
};
599+
594600
struct ggml_backend_metal_context {
595601
id<MTLDevice> device;
596602
id<MTLCommandQueue> queue;
597603

598-
// TODO: create heap per command buffer
599-
struct ggml_metal_heap * heap;
600-
601604
dispatch_queue_t d_queue;
602605

603606
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
@@ -620,7 +623,8 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
620623
void (^encode_async)(size_t ith);
621624

622625
// n_cb command buffers + 1 used by the main thread
623-
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
626+
//id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
627+
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
624628

625629
// abort ggml_metal_graph_compute if callback returns true
626630
ggml_abort_callback abort_callback;
@@ -822,8 +826,6 @@ @implementation GGMLMetalClass
822826

823827
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
824828

825-
ctx->heap = ggml_metal_heap_init(device, 1024*1024);
826-
827829
// load library
828830
if (ctx_dev->mtl_library == nil) {
829831
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
@@ -877,7 +879,11 @@ @implementation GGMLMetalClass
877879
ctx->gf = nil;
878880
ctx->encode_async = nil;
879881
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
880-
ctx->command_buffers[i] = nil;
882+
ctx->cmd_bufs[i].obj = nil;
883+
884+
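// NOTE(review): this loop stops at GGML_METAL_MAX_COMMAND_BUFFERS, but cmd_bufs has one extra slot
//               used by the main thread (index n_cb) - confirm that slot can never be used without a heap
// create a 1 MiB (1024*1024 bytes) heap for each command buffer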
885+
// these can be resized during compute when necessary
886+
ctx->cmd_bufs[i].heap = ggml_metal_heap_init(device, 1024*1024);
881887
}
882888

883889
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1268,7 +1274,11 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
12681274

12691275
[ctx->queue release];
12701276

1271-
ggml_metal_heap_free(ctx->heap);
1277+
//ggml_metal_heap_free(ctx->heap);
1278+
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
1279+
[ctx->cmd_bufs[i].obj release];
1280+
ggml_metal_heap_free(ctx->cmd_bufs[i].heap);
1281+
}
12721282

12731283
dispatch_release(ctx->d_queue);
12741284

@@ -4712,25 +4722,25 @@ static enum ggml_status ggml_metal_graph_compute(
47124722
}
47134723

47144724
// the main thread commits the first few commands immediately
4715-
// command_buffer[n_cb]
4725+
// cmd_buf[n_cb]
47164726
{
4717-
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4718-
ctx->command_buffers[n_cb] = command_buffer;
4727+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
4728+
ctx->cmd_bufs[n_cb].obj = cmd_buf;
47194729

4720-
[command_buffer enqueue];
4730+
[cmd_buf enqueue];
47214731
ctx->encode_async(n_cb);
47224732
}
47234733

47244734
// prepare the rest of the command buffers asynchronously
4725-
// command_buffer[0.. n_cb)
4735+
// cmd_buf[0.. n_cb)
47264736
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
4727-
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4728-
ctx->command_buffers[cb_idx] = command_buffer;
4737+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
4738+
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
47294739

47304740
// always enqueue the first two command buffers
47314741
// enqueue all of the command buffers if we don't need to abort
47324742
if (cb_idx < 2 || ctx->abort_callback == NULL) {
4733-
[command_buffer enqueue];
4743+
[cmd_buf enqueue];
47344744
}
47354745
}
47364746

@@ -4739,40 +4749,39 @@ static enum ggml_status ggml_metal_graph_compute(
47394749
// wait for completion and check status of each command buffer
47404750
// needed to detect if the device ran out-of-memory for example (#1881)
47414751
{
4742-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
4743-
[command_buffer waitUntilCompleted];
4752+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
4753+
[cmd_buf waitUntilCompleted];
47444754

4745-
// TODO: free main cb heap
4755+
ggml_metal_heap_reset(ctx->cmd_bufs[n_cb].heap);
47464756

4747-
MTLCommandBufferStatus status = [command_buffer status];
4757+
MTLCommandBufferStatus status = [cmd_buf status];
47484758
if (status != MTLCommandBufferStatusCompleted) {
47494759
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
47504760
if (status == MTLCommandBufferStatusError) {
4751-
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
4761+
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
47524762
}
47534763

47544764
return GGML_STATUS_FAILED;
47554765
}
47564766
}
47574767

47584768
for (int i = 0; i < n_cb; ++i) {
4759-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
4760-
[command_buffer waitUntilCompleted];
4769+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
4770+
[cmd_buf waitUntilCompleted];
47614771

4762-
// TODO: per command buffer heap
4763-
ggml_metal_heap_reset(ctx->heap);
4772+
ggml_metal_heap_reset(ctx->cmd_bufs[i].heap);
47644773

4765-
MTLCommandBufferStatus status = [command_buffer status];
4774+
MTLCommandBufferStatus status = [cmd_buf status];
47664775
if (status != MTLCommandBufferStatusCompleted) {
47674776
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
47684777
if (status == MTLCommandBufferStatusError) {
4769-
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
4778+
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
47704779
}
47714780

47724781
return GGML_STATUS_FAILED;
47734782
}
47744783

4775-
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
4784+
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
47764785
if (!next_buffer) {
47774786
continue;
47784787
}
@@ -5155,12 +5164,13 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51555164

51565165
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
51575166

5158-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
5167+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
5168+
struct ggml_metal_heap * heap = ctx->cmd_bufs[cb_idx].heap;
51595169

51605170
int n_try = 3;
51615171

51625172
while (n_try-- > 0) {
5163-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
5173+
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
51645174

51655175
int node_start = 0;
51665176
int node_end = n_nodes_0;
@@ -5177,7 +5187,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51775187
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
51785188
}
51795189

5180-
ggml_metal_encode_node(backend, idx, encoder, ctx->heap);
5190+
ggml_metal_encode_node(backend, idx, encoder, heap);
51815191

51825192
if (should_capture) {
51835193
[encoder popDebugGroup];
@@ -5186,22 +5196,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51865196

51875197
[encoder endEncoding];
51885198

5189-
if (ctx->heap->fail == 0) {
5199+
if (heap->fail == 0) {
51905200
break;
51915201
}
51925202

5193-
const size_t need = ctx->heap->need;
5203+
const size_t need = heap->need;
51945204

51955205
GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, need);
51965206

5197-
if (!ggml_metal_heap_resize(ctx->heap, need)) {
5207+
if (!ggml_metal_heap_resize(heap, need)) {
51985208
GGML_LOG_ERROR("%s: failed to increase heap size to %zu\n", __func__, need);
51995209
break;
52005210
}
52015211
}
52025212

52035213
if (cb_idx < 2 || ctx->abort_callback == NULL) {
5204-
[command_buffer commit];
5214+
[cmd_buf commit];
52055215
}
52065216
});
52075217
}

0 commit comments

Comments
 (0)