Skip to content

Commit bf8b390

Browse files
committed
metal : reuse graphs
ggml-ci
1 parent 0d2038f commit bf8b390

File tree

1 file changed

+205
-53
lines changed

1 file changed

+205
-53
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 205 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -821,13 +821,23 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
821821

822822
// the callback given to the thread pool
823823
void (^encode_async)(size_t ith);
824+
void (^encode_next)(void);
824825

825826
// n_cb command buffers + 1 used by the main thread
826827
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
828+
struct ggml_metal_command_buffer cmd_bufs_next[2];
827829

828830
// abort ggml_metal_graph_compute if callback returns true
829831
ggml_abort_callback abort_callback;
830832
void * abort_callback_data;
833+
834+
// reuse info
835+
int i_next;
836+
837+
int n_nodes_max;
838+
int n_nodes_prev;
839+
840+
struct ggml_tensor * cg_nodes;
831841
};
832842

833843
// MSL code
@@ -1084,13 +1094,21 @@ @implementation GGMLMetalClass
10841094

10851095
ctx->gf = nil;
10861096
ctx->encode_async = nil;
1097+
ctx->encode_next = nil;
10871098
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
10881099
ctx->cmd_bufs[i].obj = nil;
10891100

10901101
ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
10911102
ctx->cmd_bufs[i].mem_pool->device = device;
10921103
}
10931104

1105+
for (int i = 0; i < 2; ++i) {
1106+
ctx->cmd_bufs_next[i].obj = nil;
1107+
1108+
ctx->cmd_bufs_next[i].mem_pool = ggml_metal_mem_pool_init();
1109+
ctx->cmd_bufs_next[i].mem_pool->device = device;
1110+
}
1111+
10941112
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
10951113
if (@available(macOS 10.12, iOS 16.0, *)) {
10961114
GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
@@ -1521,6 +1539,13 @@ @implementation GGMLMetalClass
15211539
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
15221540
}
15231541

1542+
ctx->i_next = 0;
1543+
1544+
ctx->n_nodes_max = 16384;
1545+
ctx->n_nodes_prev = -1;
1546+
1547+
ctx->cg_nodes = ggml_aligned_malloc(ctx->n_nodes_max * sizeof(struct ggml_tensor));
1548+
15241549
return ctx;
15251550
}
15261551

@@ -1532,6 +1557,7 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
15321557
}
15331558

15341559
Block_release(ctx->encode_async);
1560+
Block_release(ctx->encode_next);
15351561

15361562
[ctx->queue release];
15371563

@@ -1541,8 +1567,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
15411567
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
15421568
}
15431569

1570+
ggml_metal_mem_pool_free(ctx->cmd_bufs_next[0].mem_pool);
1571+
ggml_metal_mem_pool_free(ctx->cmd_bufs_next[1].mem_pool);
1572+
15441573
dispatch_release(ctx->d_queue);
15451574

1575+
ggml_aligned_free(ctx->cg_nodes, ctx->n_nodes_max * sizeof(struct ggml_tensor));
1576+
15461577
free(ctx);
15471578
}
15481579

@@ -5448,6 +5479,39 @@ static enum ggml_status ggml_metal_graph_compute(
54485479
struct ggml_backend_metal_context * ctx = backend->context;
54495480
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
54505481

5482+
//const int64_t t_start = ggml_time_us();
5483+
5484+
/////////////////////////////////////////////////////
5485+
// hacky way to determine that the graph is the same as the previous one
5486+
//
5487+
bool can_reuse = true;
5488+
5489+
if (gf->n_nodes > ctx->n_nodes_max) {
5490+
can_reuse = false;
5491+
}
5492+
5493+
if (gf->n_nodes != ctx->n_nodes_prev) {
5494+
can_reuse = false;
5495+
}
5496+
5497+
if (can_reuse) {
5498+
for (int i = 0; i < gf->n_nodes; ++i) {
5499+
if (memcmp(gf->nodes[i], ctx->cg_nodes + i, sizeof(struct ggml_tensor)) != 0) {
5500+
can_reuse = false;
5501+
break;
5502+
}
5503+
}
5504+
}
5505+
5506+
if (!can_reuse) {
5507+
ctx->n_nodes_prev = gf->n_nodes;
5508+
5509+
for (int i = 0; i < gf->n_nodes; ++i) {
5510+
memcpy(ctx->cg_nodes + i, gf->nodes[i], sizeof(struct ggml_tensor));
5511+
}
5512+
}
5513+
//////////////////////////////////////////////////////
5514+
54515515
// number of nodes encoded by the main thread (empirically determined)
54525516
const int n_main = 128;
54535517

@@ -5492,78 +5556,126 @@ static enum ggml_status ggml_metal_graph_compute(
54925556
}
54935557
}
54945558

5495-
// the main thread commits the first few commands immediately
5496-
// cmd_buf[n_cb]
5497-
{
5498-
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5499-
ctx->cmd_bufs[n_cb].obj = cmd_buf;
5500-
5501-
[cmd_buf enqueue];
5502-
ctx->encode_async(n_cb);
5503-
}
5504-
5505-
// prepare the rest of the command buffers asynchronously
5506-
// cmd_buf[0.. n_cb)
5507-
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
5508-
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5509-
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
5559+
if (!can_reuse) {
5560+
// the main thread commits the first few commands immediately
5561+
// cmd_buf[n_cb]
5562+
{
5563+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5564+
ctx->cmd_bufs[n_cb].obj = cmd_buf;
55105565

5511-
// always enqueue the first two command buffers
5512-
// enqueue all of the command buffers if we don't need to abort
5513-
if (cb_idx < 2 || ctx->abort_callback == NULL) {
55145566
[cmd_buf enqueue];
5567+
ctx->encode_async(n_cb);
55155568
}
5516-
}
5517-
5518-
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
55195569

5520-
// wait for completion and check status of each command buffer
5521-
// needed to detect if the device ran out-of-memory for example (#1881)
5522-
{
5523-
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
5524-
[cmd_buf waitUntilCompleted];
5570+
// prepare the rest of the command buffers asynchronously
5571+
// cmd_buf[0.. n_cb)
5572+
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
5573+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5574+
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
55255575

5526-
MTLCommandBufferStatus status = [cmd_buf status];
5527-
if (status != MTLCommandBufferStatusCompleted) {
5528-
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
5529-
if (status == MTLCommandBufferStatusError) {
5530-
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
5576+
// always enqueue the first two command buffers
5577+
// enqueue all of the command buffers if we don't need to abort
5578+
if (cb_idx < 2 || ctx->abort_callback == NULL) {
5579+
[cmd_buf enqueue];
55315580
}
5532-
5533-
return GGML_STATUS_FAILED;
55345581
}
5535-
}
55365582

5537-
for (int i = 0; i < n_cb; ++i) {
5538-
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
5539-
[cmd_buf waitUntilCompleted];
5583+
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
55405584

5541-
MTLCommandBufferStatus status = [cmd_buf status];
5542-
if (status != MTLCommandBufferStatusCompleted) {
5543-
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
5544-
if (status == MTLCommandBufferStatusError) {
5545-
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
5585+
// encode the command buffer for the next iter while the GPU has already started
5586+
{
5587+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5588+
[cmd_buf retain];
5589+
if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
5590+
[ctx->cmd_bufs_next[ctx->i_next].obj release];
55465591
}
5592+
ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
55475593

5548-
return GGML_STATUS_FAILED;
5594+
ctx->encode_next();
55495595
}
55505596

5551-
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
5552-
if (!next_buffer) {
5553-
continue;
5597+
// wait for completion and check status of each command buffer
5598+
// needed to detect if the device ran out-of-memory for example (#1881)
5599+
{
5600+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
5601+
[cmd_buf waitUntilCompleted];
5602+
5603+
MTLCommandBufferStatus status = [cmd_buf status];
5604+
if (status != MTLCommandBufferStatusCompleted) {
5605+
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
5606+
if (status == MTLCommandBufferStatusError) {
5607+
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
5608+
}
5609+
5610+
return GGML_STATUS_FAILED;
5611+
}
55545612
}
55555613

5556-
const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
5557-
if (next_queued) {
5558-
continue;
5614+
for (int i = 0; i < n_cb; ++i) {
5615+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
5616+
[cmd_buf waitUntilCompleted];
5617+
5618+
MTLCommandBufferStatus status = [cmd_buf status];
5619+
if (status != MTLCommandBufferStatusCompleted) {
5620+
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
5621+
if (status == MTLCommandBufferStatusError) {
5622+
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
5623+
}
5624+
5625+
return GGML_STATUS_FAILED;
5626+
}
5627+
5628+
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
5629+
if (!next_buffer) {
5630+
continue;
5631+
}
5632+
5633+
const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
5634+
if (next_queued) {
5635+
continue;
5636+
}
5637+
5638+
if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
5639+
GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
5640+
return GGML_STATUS_ABORTED;
5641+
}
5642+
5643+
[next_buffer commit];
55595644
}
5645+
} else {
5646+
struct ggml_metal_command_buffer cmd_buf_cur = ctx->cmd_bufs_next[(ctx->i_next + 1)%2];
5647+
5648+
// directly submit the command buffer that we have prepared in the previous iteration
5649+
[ctx->cmd_bufs_next[(ctx->i_next + 1)%2].obj commit];
55605650

5561-
if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
5562-
GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
5563-
return GGML_STATUS_ABORTED;
5651+
// encode the command buffer for the next iter
5652+
{
5653+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5654+
[cmd_buf retain];
5655+
if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
5656+
[ctx->cmd_bufs_next[ctx->i_next].obj release];
5657+
}
5658+
ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
5659+
5660+
ctx->encode_next();
55645661
}
55655662

5566-
[next_buffer commit];
5663+
// wait for completion and check status of each command buffer
5664+
// needed to detect if the device ran out-of-memory for example (#1881)
5665+
{
5666+
id<MTLCommandBuffer> cmd_buf = cmd_buf_cur.obj;
5667+
[cmd_buf waitUntilCompleted];
5668+
5669+
MTLCommandBufferStatus status = [cmd_buf status];
5670+
if (status != MTLCommandBufferStatusCompleted) {
5671+
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, ctx->i_next, status);
5672+
if (status == MTLCommandBufferStatusError) {
5673+
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
5674+
}
5675+
5676+
return GGML_STATUS_FAILED;
5677+
}
5678+
}
55675679
}
55685680

55695681
if (!should_capture && ctx->capture_started) {
@@ -5572,6 +5684,8 @@ static enum ggml_status ggml_metal_graph_compute(
55725684
}
55735685
}
55745686

5687+
//printf(" time = %.3f ms\n", (float)(ggml_time_us() - t_start)/1000.0f);
5688+
55755689
return GGML_STATUS_SUCCESS;
55765690
}
55775691

@@ -5919,6 +6033,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
59196033
Block_release(ctx->encode_async);
59206034
}
59216035

6036+
if (ctx->encode_next) {
6037+
Block_release(ctx->encode_next);
6038+
}
6039+
59226040
ctx->encode_async = Block_copy(^(size_t iter) {
59236041
const int cb_idx = iter;
59246042
const int n_cb_l = ctx->n_cb;
@@ -5967,6 +6085,40 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
59676085
[cmd_buf commit];
59686086
}
59696087
});
6088+
6089+
ctx->encode_next = Block_copy(^(void) {
6090+
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_next[ctx->i_next].obj;
6091+
6092+
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
6093+
6094+
int node_start = 0;
6095+
int node_end = ctx->gf->n_nodes;
6096+
6097+
const bool should_capture = ctx->capture_next_compute;
6098+
6099+
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs_next[ctx->i_next].mem_pool;
6100+
ggml_metal_mem_pool_reset(mem_pool);
6101+
6102+
for (int idx = node_start; idx < node_end; ++idx) {
6103+
if (should_capture) {
6104+
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
6105+
}
6106+
6107+
const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
6108+
6109+
if (should_capture) {
6110+
[encoder popDebugGroup];
6111+
}
6112+
6113+
if (!res) {
6114+
break;
6115+
}
6116+
}
6117+
6118+
[encoder endEncoding];
6119+
6120+
ctx->i_next = (ctx->i_next + 1) % 2;
6121+
});
59706122
}
59716123

59726124
static struct ggml_backend_i ggml_backend_metal_i = {

0 commit comments

Comments
 (0)