Skip to content

Commit 88d496e

Browse files
committed
cont : heap for each cmd buffer [no ci]
1 parent 0e1f5aa commit 88d496e

File tree

1 file changed

+45
-35
lines changed

1 file changed

+45
-35
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -591,13 +591,16 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
591591
return buf;
592592
}
593593

594+
// A Metal command buffer bundled with the heap used while encoding into it.
//
// The backend context holds an array of these (cmd_bufs), so every command
// buffer slot has its own heap rather than sharing a single context-wide one.
struct ggml_metal_command_buffer {
    // command buffer handle; assigned during graph compute from
    // [queue commandBufferWithUnretainedReferences]
    id<MTLCommandBuffer> obj;

    // heap passed to ggml_metal_encode_node() for the nodes of this command buffer;
    // created via ggml_metal_heap_init(), grown with ggml_metal_heap_resize() when
    // encoding reports a failure, and reset with ggml_metal_heap_reset() after the
    // command buffer has been waited on
    struct ggml_metal_heap * heap;
};
599+
594600
struct ggml_backend_metal_context {
595601
id<MTLDevice> device;
596602
id<MTLCommandQueue> queue;
597603

598-
// TODO: create heap per command buffer
599-
struct ggml_metal_heap * heap;
600-
601604
dispatch_queue_t d_queue;
602605

603606
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
@@ -620,7 +623,8 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
620623
void (^encode_async)(size_t ith);
621624

622625
// n_cb command buffers + 1 used by the main thread
623-
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
626+
//id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
627+
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
624628

625629
// abort ggml_metal_graph_compute if callback returns true
626630
ggml_abort_callback abort_callback;
@@ -822,8 +826,6 @@ @implementation GGMLMetalClass
822826

823827
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
824828

825-
ctx->heap = ggml_metal_heap_init(device, 1024*1024);
826-
827829
// load library
828830
if (ctx_dev->mtl_library == nil) {
829831
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
@@ -877,7 +879,11 @@ @implementation GGMLMetalClass
877879
ctx->gf = nil;
878880
ctx->encode_async = nil;
879881
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
880-
ctx->command_buffers[i] = nil;
882+
ctx->cmd_bufs[i].obj = nil;
883+
884+
// create 1MB heaps per command buffer
885+
// these can be resized during compute when necessary
886+
ctx->cmd_bufs[i].heap = ggml_metal_heap_init(device, 1024*1024);
881887
}
882888

883889
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1268,7 +1274,11 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
12681274

12691275
[ctx->queue release];
12701276

1271-
ggml_metal_heap_free(ctx->heap);
1277+
//ggml_metal_heap_free(ctx->heap);
1278+
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
1279+
[ctx->cmd_bufs[i].obj release];
1280+
ggml_metal_heap_free(ctx->cmd_bufs[i].heap);
1281+
}
12721282

12731283
dispatch_release(ctx->d_queue);
12741284

@@ -4711,25 +4721,25 @@ static enum ggml_status ggml_metal_graph_compute(
47114721
}
47124722

47134723
// the main thread commits the first few commands immediately
4714-
// command_buffer[n_cb]
4724+
// cmd_buf[n_cb]
47154725
{
4716-
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4717-
ctx->command_buffers[n_cb] = command_buffer;
4726+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
4727+
ctx->cmd_bufs[n_cb].obj = cmd_buf;
47184728

4719-
[command_buffer enqueue];
4729+
[cmd_buf enqueue];
47204730
ctx->encode_async(n_cb);
47214731
}
47224732

47234733
// prepare the rest of the command buffers asynchronously
4724-
// command_buffer[0.. n_cb)
4734+
// cmd_buf[0.. n_cb)
47254735
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
4726-
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
4727-
ctx->command_buffers[cb_idx] = command_buffer;
4736+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
4737+
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
47284738

47294739
// always enqueue the first two command buffers
47304740
// enqueue all of the command buffers if we don't need to abort
47314741
if (cb_idx < 2 || ctx->abort_callback == NULL) {
4732-
[command_buffer enqueue];
4742+
[cmd_buf enqueue];
47334743
}
47344744
}
47354745

@@ -4738,40 +4748,39 @@ static enum ggml_status ggml_metal_graph_compute(
47384748
// wait for completion and check status of each command buffer
47394749
// needed to detect if the device ran out-of-memory for example (#1881)
47404750
{
4741-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
4742-
[command_buffer waitUntilCompleted];
4751+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
4752+
[cmd_buf waitUntilCompleted];
47434753

4744-
// TODO: free main cb heap
4754+
ggml_metal_heap_reset(ctx->cmd_bufs[n_cb].heap);
47454755

4746-
MTLCommandBufferStatus status = [command_buffer status];
4756+
MTLCommandBufferStatus status = [cmd_buf status];
47474757
if (status != MTLCommandBufferStatusCompleted) {
47484758
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
47494759
if (status == MTLCommandBufferStatusError) {
4750-
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
4760+
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
47514761
}
47524762

47534763
return GGML_STATUS_FAILED;
47544764
}
47554765
}
47564766

47574767
for (int i = 0; i < n_cb; ++i) {
4758-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
4759-
[command_buffer waitUntilCompleted];
4768+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
4769+
[cmd_buf waitUntilCompleted];
47604770

4761-
// TODO: per command buffer heap
4762-
ggml_metal_heap_reset(ctx->heap);
4771+
ggml_metal_heap_reset(ctx->cmd_bufs[i].heap);
47634772

4764-
MTLCommandBufferStatus status = [command_buffer status];
4773+
MTLCommandBufferStatus status = [cmd_buf status];
47654774
if (status != MTLCommandBufferStatusCompleted) {
47664775
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
47674776
if (status == MTLCommandBufferStatusError) {
4768-
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
4777+
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
47694778
}
47704779

47714780
return GGML_STATUS_FAILED;
47724781
}
47734782

4774-
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
4783+
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
47754784
if (!next_buffer) {
47764785
continue;
47774786
}
@@ -5154,12 +5163,13 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51545163

51555164
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
51565165

5157-
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
5166+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
5167+
struct ggml_metal_heap * heap = ctx->cmd_bufs[cb_idx].heap;
51585168

51595169
int n_try = 3;
51605170

51615171
while (n_try-- > 0) {
5162-
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
5172+
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
51635173

51645174
int node_start = 0;
51655175
int node_end = n_nodes_0;
@@ -5176,7 +5186,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51765186
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
51775187
}
51785188

5179-
ggml_metal_encode_node(backend, idx, encoder, ctx->heap);
5189+
ggml_metal_encode_node(backend, idx, encoder, heap);
51805190

51815191
if (should_capture) {
51825192
[encoder popDebugGroup];
@@ -5185,22 +5195,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
51855195

51865196
[encoder endEncoding];
51875197

5188-
if (ctx->heap->fail == 0) {
5198+
if (heap->fail == 0) {
51895199
break;
51905200
}
51915201

5192-
const size_t need = ctx->heap->need;
5202+
const size_t need = heap->need;
51935203

51945204
GGML_LOG_INFO("%s: increasing heap size to %zu\n", __func__, need);
51955205

5196-
if (!ggml_metal_heap_resize(ctx->heap, need)) {
5206+
if (!ggml_metal_heap_resize(heap, need)) {
51975207
GGML_LOG_ERROR("%s: failed to increase heap size to %zu\n", __func__, need);
51985208
break;
51995209
}
52005210
}
52015211

52025212
if (cb_idx < 2 || ctx->abort_callback == NULL) {
5203-
[command_buffer commit];
5213+
[cmd_buf commit];
52045214
}
52055215
});
52065216
}

0 commit comments

Comments
 (0)