Skip to content

Commit 9248aec

Browse files
committed
metal : simplify synchronization logic
ggml-ci
1 parent 523750a commit 9248aec

File tree

1 file changed

+19
-39
lines changed

1 file changed

+19
-39
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -828,8 +828,8 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
828828
// extra command buffers for things like getting, setting and copying tensors
829829
NSMutableArray * cmd_bufs_ext;
830830

831+
// the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
831832
id<MTLCommandBuffer> cmd_buf_last;
832-
id<MTLCommandBuffer> cmd_buf_ext_last;
833833

834834
// abort ggml_metal_graph_compute if callback returns true
835835
ggml_abort_callback abort_callback;
@@ -1110,7 +1110,6 @@ @implementation GGMLMetalClass
11101110
ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];
11111111

11121112
ctx->cmd_buf_last = nil;
1113-
ctx->cmd_buf_ext_last = nil;
11141113

11151114
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
11161115
if (@available(macOS 10.12, iOS 16.0, *)) {
@@ -5735,6 +5734,12 @@ static enum ggml_status ggml_metal_graph_compute(
57355734
if (should_capture) {
57365735
ctx->capture_next_compute = false;
57375736

5737+
// make sure all previous computations have finished before starting the capture
5738+
if (ctx->cmd_buf_last) {
5739+
[ctx->cmd_buf_last waitUntilCompleted];
5740+
ctx->cmd_buf_last = nil;
5741+
}
5742+
57385743
if (!ctx->capture_started) {
57395744
// create capture scope
57405745
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
@@ -5757,16 +5762,11 @@ static enum ggml_status ggml_metal_graph_compute(
57575762
// the main thread commits the first few commands immediately
57585763
// cmd_buf[n_cb]
57595764
{
5760-
// first wait for any previous command buffer to be completed
5761-
// note: this checks only that the first part of the previous graph has been computed
5762-
// the rest of the graph might still be computing, but it is Ok to start queuing the beginning of the
5763-
/// new graph
5764-
if (ctx->cmd_bufs[n_cb].obj) {
5765-
[ctx->cmd_bufs[n_cb].obj waitUntilCompleted];
5766-
[ctx->cmd_bufs[n_cb].obj release];
5767-
}
5768-
5769-
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5765+
// cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
5766+
// TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
5767+
// https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
5768+
//id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5769+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
57705770
[cmd_buf retain];
57715771

57725772
ctx->cmd_bufs[n_cb].obj = cmd_buf;
@@ -5776,23 +5776,14 @@ static enum ggml_status ggml_metal_graph_compute(
57765776
ctx->encode_async(n_cb);
57775777
}
57785778

5779-
// here we guarantee the full previous graph has finished computing
5780-
// but note that we have already enqueued the first part of the new graph so it can start processing, while
5781-
// we continue to encode the rest of the graph
5782-
// TODO: remove these waits after we remove the memory pools
5783-
// https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
5784-
if (ctx->cmd_buf_last) {
5785-
[ctx->cmd_buf_last waitUntilCompleted];
5786-
ctx->cmd_buf_last = nil;
5787-
}
5788-
57895779
// remember the command buffer for the next iteration
57905780
ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;
57915781

57925782
// prepare the rest of the command buffers asynchronously (optional)
57935783
// cmd_buf[0.. n_cb)
57945784
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
5795-
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5785+
//id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
5786+
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
57965787
[cmd_buf retain];
57975788

57985789
if (ctx->cmd_bufs[cb_idx].obj) {
@@ -5868,11 +5859,6 @@ static enum ggml_status ggml_metal_graph_compute(
58685859
[next_buffer commit];
58695860
}
58705861

5871-
if (ctx->cmd_buf_last) {
5872-
[ctx->cmd_buf_last waitUntilCompleted];
5873-
ctx->cmd_buf_last = nil;
5874-
}
5875-
58765862
[ctx->capture_scope endScope];
58775863
[[MTLCaptureManager sharedCaptureManager] stopCapture];
58785864
}
@@ -6017,7 +6003,7 @@ static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, c
60176003
deallocator:nil];
60186004

60196005
id<MTLCommandQueue> queue = ctx->queue;
6020-
id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
6006+
id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
60216007

60226008
{
60236009
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
@@ -6062,7 +6048,7 @@ static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_
60626048
} else {
60636049
@autoreleasepool {
60646050
id<MTLCommandQueue> queue = ctx->queue;
6065-
id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
6051+
id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
60666052

60676053
{
60686054
id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
@@ -6350,18 +6336,12 @@ static void ggml_backend_metal_free(ggml_backend_t backend) {
63506336
static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
63516337
struct ggml_backend_metal_context * ctx = backend->context;
63526338

6353-
// wait for the computation of the graph to finish
6339+
// wait for any backend operations to finish
63546340
if (ctx->cmd_buf_last) {
63556341
[ctx->cmd_buf_last waitUntilCompleted];
63566342
ctx->cmd_buf_last = nil;
63576343
}
63586344

6359-
// wait for any pending async get/set operations
6360-
if (ctx->cmd_buf_ext_last) {
6361-
[ctx->cmd_buf_ext_last waitUntilCompleted];
6362-
ctx->cmd_buf_ext_last = nil;
6363-
}
6364-
63656345
// release any completed command buffers
63666346
if (ctx->cmd_bufs_ext.count > 0) {
63676347
for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
@@ -6423,7 +6403,7 @@ static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, st
64236403

64246404
// instead, remember a reference to the command buffer and wait for it later if needed
64256405
[ctx->cmd_bufs_ext addObject:cmd_buf];
6426-
ctx->cmd_buf_ext_last = cmd_buf;
6406+
ctx->cmd_buf_last = cmd_buf;
64276407

64286408
[cmd_buf retain];
64296409
}
@@ -6469,7 +6449,7 @@ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const st
64696449

64706450
// instead, remember a reference to the command buffer and wait for it later if needed
64716451
[ctx->cmd_bufs_ext addObject:cmd_buf];
6472-
ctx->cmd_buf_ext_last = cmd_buf;
6452+
ctx->cmd_buf_last = cmd_buf;
64736453

64746454
[cmd_buf retain];
64756455
}

0 commit comments

Comments
 (0)