Skip to content

Commit 5273e59

Browse files
committed
metal : add comments
1 parent 43b9d69 commit 5273e59

File tree

1 file changed

+49
-46
lines changed

1 file changed

+49
-46
lines changed

ggml/src/ggml-metal.m

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
// max memory buffers that can be mapped to the device
1616
#define GGML_METAL_MAX_BUFFERS 64
1717

18-
#define GGML_METAL_MAX_COMMAND_BUFFERS 128
18+
// max number of MTLCommandBuffer used to submit a graph for processing
19+
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
1920

2021
#ifdef GGML_METAL_NDEBUG
2122
#define GGML_METAL_LOG(...)
@@ -226,8 +227,6 @@
226227
};
227228

228229
struct ggml_backend_metal_context {
229-
int n_cb;
230-
231230
id<MTLDevice> device;
232231
id<MTLCommandQueue> queue;
233232

@@ -240,21 +239,28 @@
240239
bool support_simdgroup_reduction;
241240
bool support_simdgroup_mm;
242241

243-
bool should_capture_next_compute;
242+
// capture state
243+
bool capture_next_compute;
244244
bool capture_started;
245245

246-
id<MTLCaptureScope> cap_scope;
247-
248-
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
246+
id<MTLCaptureScope> capture_scope;
249247

250-
int n_nodes_0;
251-
int n_nodes_1;
248+
// command buffer state
249+
int n_cb; // number of extra threads used to submit the command buffers
250+
int n_nodes_0; // number of nodes submitted by the main thread
251+
int n_nodes_1; // remaining number of nodes submitted by the n_cb threads
252252
int n_nodes_per_cb;
253253

254254
struct ggml_cgraph * gf;
255255

256+
// the callback given to the thread pool
257+
// TODO: ideally, this should be created once, utilizing the command buffer state above
258+
// for some reason, doing it like this leads to a crash
256259
void (^encode_async)(size_t ith);
257260

261+
// n_cb command buffers + 1 used by the main thread
262+
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
263+
258264
// abort ggml_metal_graph_compute if callback returns true
259265
ggml_abort_callback abort_callback;
260266
void * abort_callback_data;
@@ -476,17 +482,16 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
476482
GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
477483
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
478484

479-
ctx->should_capture_next_compute = false;
485+
ctx->capture_next_compute = false;
480486
ctx->capture_started = false;
487+
ctx->capture_scope = nil;
481488

482-
ctx->cap_scope = nil;
483-
489+
ctx->gf = nil;
490+
ctx->encode_async = nil;
484491
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
485492
ctx->command_buffers[i] = nil;
486493
}
487494

488-
ctx->encode_async = nil;
489-
490495
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
491496
if (@available(macOS 10.12, iOS 16.0, *)) {
492497
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
@@ -3000,31 +3005,37 @@ static void ggml_metal_encode_node(
30003005
static enum ggml_status ggml_metal_graph_compute(
30013006
struct ggml_backend_metal_context * ctx,
30023007
struct ggml_cgraph * gf) {
3003-
@autoreleasepool {
3004-
// create multiple command buffers and enqueue them
3005-
// then, we encode the graph into the command buffers in parallel
3008+
// number of nodes encoded by the main thread (empirically determined)
3009+
const int n_main = 128;
30063010

3007-
const int n_cb = ctx->n_cb;
3011+
// number of threads in addition to the main thread
3012+
const int n_cb = ctx->n_cb;
30083013

3014+
// submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
3015+
// the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
3016+
// while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
3017+
// each thread creates its own command buffer and enqueues the ops in parallel
3018+
//
3019+
// tests on M1 Pro and M2 Ultra using LLaMA models show that optimal values for n_cb are 1 or 2
3020+
3021+
@autoreleasepool {
30093022
ctx->gf = gf;
30103023

3011-
ctx->n_nodes_0 = MIN(128, gf->n_nodes);
3024+
ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
30123025
ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
30133026

3014-
ctx->n_nodes_per_cb = (ctx->n_nodes_1 + n_cb - 1) / n_cb;
3027+
ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
30153028

3016-
//const int64_t t_start = ggml_time_us();
3017-
3018-
const bool should_capture = ctx->should_capture_next_compute;
3029+
const bool should_capture = ctx->capture_next_compute;
30193030
if (should_capture) {
3020-
ctx->should_capture_next_compute = false;
3031+
ctx->capture_next_compute = false;
30213032

30223033
if (!ctx->capture_started) {
30233034
// create capture scope
3024-
ctx->cap_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
3035+
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
30253036

30263037
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
3027-
descriptor.captureObject = ctx->cap_scope;
3038+
descriptor.captureObject = ctx->capture_scope;
30283039
descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
30293040
descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
30303041

@@ -3033,7 +3044,7 @@ static enum ggml_status ggml_metal_graph_compute(
30333044
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
30343045
GGML_ABORT("capture failed");
30353046
} else {
3036-
[ctx->cap_scope beginScope];
3047+
[ctx->capture_scope beginScope];
30373048
ctx->capture_started = true;
30383049
}
30393050
}
@@ -3055,7 +3066,7 @@ static enum ggml_status ggml_metal_graph_compute(
30553066
int node_start = 0;
30563067
int node_end = n_nodes_0;
30573068

3058-
if ((int) iter < n_cb_l) {
3069+
if (cb_idx < n_cb_l) {
30593070
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
30603071
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
30613072
}
@@ -3079,17 +3090,20 @@ static enum ggml_status ggml_metal_graph_compute(
30793090
}
30803091
};
30813092

3093+
// the main thread commits the first few commands immediately
3094+
// command_buffer[n_cb]
30823095
{
30833096
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
30843097
ctx->command_buffers[n_cb] = command_buffer;
30853098

30863099
[command_buffer enqueue];
3087-
30883100
ctx->encode_async(n_cb);
30893101
}
30903102

3103+
// prepare the rest of the command buffers asynchronously
3104+
// command_buffer[0..n_cb)
30913105
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
3092-
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
3106+
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
30933107
ctx->command_buffers[cb_idx] = command_buffer;
30943108

30953109
// always enqueue the first two command buffers
@@ -3101,14 +3115,8 @@ static enum ggml_status ggml_metal_graph_compute(
31013115

31023116
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
31033117

3104-
//{
3105-
// const int64_t t_end = ggml_time_us();
3106-
// //printf("time to encode: %d us, n_cb = %d\n", (int) (t_end - t_start), n_cb);
3107-
//}
3108-
3109-
// Wait for completion and check status of each command buffer
3118+
// wait for completion and check status of each command buffer
31103119
// needed to detect if the device ran out-of-memory for example (#1881)
3111-
31123120
{
31133121
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
31143122
[command_buffer waitUntilCompleted];
@@ -3143,7 +3151,7 @@ static enum ggml_status ggml_metal_graph_compute(
31433151
continue;
31443152
}
31453153

3146-
bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
3154+
const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
31473155
if (next_queued) {
31483156
continue;
31493157
}
@@ -3156,13 +3164,8 @@ static enum ggml_status ggml_metal_graph_compute(
31563164
[next_buffer commit];
31573165
}
31583166

3159-
//{
3160-
// const int64_t t_end = ggml_time_us();
3161-
// printf("time to compute: %d us\n", (int)(t_end - t_start));
3162-
//}
3163-
31643167
if (!should_capture && ctx->capture_started) {
3165-
[ctx->cap_scope endScope];
3168+
[ctx->capture_scope endScope];
31663169
[[MTLCaptureManager sharedCaptureManager] stopCapture];
31673170
}
31683171
}
@@ -3514,7 +3517,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
35143517
}
35153518
}
35163519

3517-
// TODO: setting encode_async here causes crash. why?
3520+
// TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
35183521
//ctx->encode_async = ^(size_t iter) {
35193522
// ...
35203523
//};
@@ -3598,7 +3601,7 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
35983601
GGML_ASSERT(ggml_backend_is_metal(backend));
35993602

36003603
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
3601-
ctx->should_capture_next_compute = true;
3604+
ctx->capture_next_compute = true;
36023605
}
36033606

36043607
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning

0 commit comments

Comments
 (0)