Skip to content

Commit 6d5cb56

Browse files
committed
wip
1 parent d7b5934 commit 6d5cb56

File tree

2 files changed

+51
-20
lines changed

2 files changed

+51
-20
lines changed

examples/perf-metal/perf-metal.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,10 @@ int main(int argc, char ** argv) {
107107
if (n_thread == 4) {
108108
ggml_backend_metal_capture_next_compute(backend);
109109
ggml_backend_graph_compute(backend, gf);
110+
//std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // NOTE: these intervals do not appear in the XCode trace!
110111
ggml_backend_metal_capture_next_compute(backend);
111112
ggml_backend_graph_compute(backend, gf);
113+
//std::this_thread::sleep_for(std::chrono::milliseconds(1000)); // NOTE: these intervals do not appear in the XCode trace!
112114
ggml_backend_metal_capture_next_compute(backend);
113115
ggml_backend_graph_compute(backend, gf);
114116

ggml/src/ggml-metal.m

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,8 @@
236236
bool should_capture_next_compute;
237237
bool capture_started;
238238

239+
id<MTLCaptureScope> cap_scope;
240+
239241
// abort ggml_metal_graph_compute if callback returns true
240242
ggml_abort_callback abort_callback;
241243
void * abort_callback_data;
@@ -459,6 +461,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
459461
ctx->should_capture_next_compute = false;
460462
ctx->capture_started = false;
461463

464+
ctx->cap_scope = nil;
465+
462466
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
463467
if (@available(macOS 10.12, iOS 16.0, *)) {
464468
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
@@ -887,17 +891,21 @@ static enum ggml_status ggml_metal_graph_compute(
887891
// create multiple command buffers and enqueue them
888892
// then, we encode the graph into the command buffers in parallel
889893

890-
const int n_nodes = gf->n_nodes;
894+
const int n_nodes_0 = MIN(64, gf->n_nodes);
895+
const int n_nodes_1 = gf->n_nodes - n_nodes_0;
891896
const int n_cb = ctx->n_cb;
892-
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
897+
const int n_nodes_per_cb = (n_nodes_1 + n_cb - 1) / n_cb;
893898

894899
const bool should_capture = ctx->should_capture_next_compute;
895900
if (should_capture) {
896901
ctx->should_capture_next_compute = false;
897902

898903
if (!ctx->capture_started) {
904+
// create capture scope
905+
ctx->cap_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
906+
899907
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
900-
descriptor.captureObject = ctx->queue;
908+
descriptor.captureObject = ctx->cap_scope;
901909
descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
902910
descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
903911

@@ -906,26 +914,17 @@ static enum ggml_status ggml_metal_graph_compute(
906914
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
907915
GGML_ABORT("capture failed");
908916
} else {
917+
[ctx->cap_scope beginScope];
909918
ctx->capture_started = true;
910919
}
911920
}
912921
}
913922

914-
id<MTLCommandBuffer> command_buffer_builder[n_cb];
915-
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
916-
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
917-
command_buffer_builder[cb_idx] = command_buffer;
918-
919-
// always enqueue the first two command buffers
920-
// enqueue all of the command buffers if we don't need to abort
921-
if (cb_idx < 2 || ctx->abort_callback == NULL) {
922-
[command_buffer enqueue];
923-
}
924-
}
925-
926-
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
923+
id<MTLCommandBuffer> command_buffer_builder[n_cb + 1];
924+
const id<MTLCommandBuffer> * command_buffers = command_buffer_builder;
927925

928-
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
926+
//dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
927+
void (^helper)(size_t iter) = ^(size_t iter) {
929928
const int cb_idx = iter;
930929

931930
size_t offs_src0 = 0;
@@ -936,8 +935,12 @@ static enum ggml_status ggml_metal_graph_compute(
936935
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
937936
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
938937

939-
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
940-
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
938+
int node_start = 0;
939+
int node_end = n_nodes_0;
940+
if ((int) iter < n_cb) {
941+
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
942+
node_end = n_nodes_0 + (MIN((cb_idx == n_cb - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
943+
}
941944

942945
for (int i = node_start; i < node_end; ++i) {
943946
if (i == -1) {
@@ -3037,11 +3040,36 @@ static enum ggml_status ggml_metal_graph_compute(
30373040
if (cb_idx < 2 || ctx->abort_callback == NULL) {
30383041
[command_buffer commit];
30393042
}
3040-
});
3043+
};
3044+
3045+
{
3046+
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
3047+
command_buffer_builder[n_cb] = command_buffer;
3048+
[command_buffer enqueue];
3049+
helper(n_cb);
3050+
}
3051+
3052+
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
3053+
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
3054+
command_buffer_builder[cb_idx] = command_buffer;
3055+
3056+
// always enqueue the first two command buffers
3057+
// enqueue all of the command buffers if we don't need to abort
3058+
if (cb_idx < 2 || ctx->abort_callback == NULL) {
3059+
[command_buffer enqueue];
3060+
}
3061+
}
3062+
3063+
dispatch_apply(n_cb, ctx->d_queue, helper);
30413064

30423065
// Wait for completion and check status of each command buffer
30433066
// needed to detect if the device ran out-of-memory for example (#1881)
30443067

3068+
{
3069+
id<MTLCommandBuffer> command_buffer = command_buffers[n_cb];
3070+
[command_buffer waitUntilCompleted];
3071+
}
3072+
30453073
for (int i = 0; i < n_cb; ++i) {
30463074
id<MTLCommandBuffer> command_buffer = command_buffers[i];
30473075
[command_buffer waitUntilCompleted];
@@ -3075,6 +3103,7 @@ static enum ggml_status ggml_metal_graph_compute(
30753103
}
30763104

30773105
if (!should_capture && ctx->capture_started) {
3106+
[ctx->cap_scope endScope];
30783107
[[MTLCaptureManager sharedCaptureManager] stopCapture];
30793108
}
30803109

0 commit comments

Comments
 (0)