Skip to content

Commit 6d78c5e

Browse files
committed
wip [no ci]
1 parent d814b70 commit 6d78c5e

File tree

2 files changed

+70
-22
lines changed

2 files changed

+70
-22
lines changed

ggml/src/ggml-metal/ggml-metal-context.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,23 @@ void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort
6363
bool ggml_metal_supports_family (ggml_metal_t ctx, int family);
6464
void ggml_metal_capture_next_compute(ggml_metal_t ctx);
6565

66+
//
67+
// encoder
68+
//
69+
70+
typedef struct ggml_metal_encoder * ggml_metal_encoder_t;
71+
72+
ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_t ctx, int cb_idx);
73+
void ggml_metal_encoder_free(ggml_metal_encoder_t ctx);
74+
75+
void ggml_metal_encoder_begin (ggml_metal_encoder_t ctx, int idx);
76+
void ggml_metal_encoder_encode(ggml_metal_encoder_t ctx, int idx, int node_end);
77+
void ggml_metal_encoder_end (ggml_metal_encoder_t ctx, int idx);
78+
79+
bool ggml_metal_encoder_concurrency_reset(ggml_metal_encoder_t ctx);
80+
bool ggml_metal_encoder_concurrency_check(ggml_metal_encoder_t ctx, const struct ggml_tensor * node);
81+
bool ggml_metal_encoder_concurrency_add (ggml_metal_encoder_t ctx, const struct ggml_tensor * node);
82+
6683
#ifdef __cplusplus
6784
}
6885
#endif

ggml/src/ggml-metal/ggml-metal-context.m

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {
379379
id<MTLLibrary> library;
380380
id<MTLCommandQueue> queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND]
381381

382-
//struct ggml_metal_device_props props_dev;
383382
ggml_metal_device_t ctx_dev;
384383

385384
dispatch_queue_t d_queue;
@@ -1062,15 +1061,47 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
10621061
}
10631062
}
10641063

1065-
struct ggml_metal_encode_context {
1066-
id<MTLComputeCommandEncoder> encoder;
1067-
1064+
struct ggml_metal_encoder {
10681065
ggml_metal_t ctx;
10691066

1067+
id<MTLComputeCommandEncoder> encoder;
1068+
10701069
ggml_mem_ranges_t mem_ranges;
10711070
};
10721071

1073-
static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context * ctx) {
1072+
ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_t ctx, int cb_idx) {
1073+
ggml_metal_encoder_t res = calloc(1, sizeof(struct ggml_metal_encoder));
1074+
res->ctx = ctx;
1075+
1076+
id<MTLCommandBuffer> cmd_buf = [ctx->que
1077+
1078+
if (ctx->use_concurrency) {
1079+
res->encoder = [ctx->queue computeCommandEncoder];
1080+
res->mem_ranges = ggml_mem_ranges_init(ctx->debug_graph);
1081+
} else {
1082+
res->mem_ranges = nil;
1083+
}
1084+
1085+
}
1086+
1087+
1088+
void ggml_metal_encoder_free(ggml_metal_encoder_t ctx);
1089+
1090+
void ggml_metal_encoder_begin(ggml_metal_encoder_t ctx, int idx) {
1091+
if (ctx->ctx->capture_next_compute) {
1092+
[ctx->encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
1093+
}
1094+
}
1095+
1096+
void ggml_metal_encoder_end(ggml_metal_encoder_t ctx, int idx) {
1097+
if (ctx->ctx->capture_next_compute) {
1098+
[ctx->encoder popDebugGroup];
1099+
}
1100+
1101+
GGML_UNUSED(idx);
1102+
}
1103+
1104+
bool ggml_metal_encoder_concurrency_reset(struct ggml_metal_encoder * ctx) {
10741105
if (!ctx->mem_ranges) {
10751106
return true;
10761107
}
@@ -1082,23 +1113,23 @@ static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context
10821113
return true;
10831114
}
10841115

1085-
static bool ggml_metal_encode_concurrency_check(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
1116+
bool ggml_metal_encoder_concurrency_check(struct ggml_metal_encoder * ctx, const struct ggml_tensor * node) {
10861117
if (!ctx->mem_ranges) {
10871118
return false;
10881119
}
10891120

10901121
return ggml_mem_ranges_check(ctx->mem_ranges, node);
10911122
}
10921123

1093-
static bool ggml_metal_encode_concurrency_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
1124+
bool ggml_metal_encoder_concurrency_add(struct ggml_metal_encoder * ctx, const struct ggml_tensor * node) {
10941125
if (!ctx->mem_ranges) {
10951126
return true;
10961127
}
10971128

10981129
return ggml_mem_ranges_add(ctx->mem_ranges, node);
10991130
}
11001131

1101-
static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
1132+
static int ggml_metal_encoder_node(struct ggml_metal_encoder * ctx_enc, int idx, int idx_end) {
11021133
id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
11031134

11041135
ggml_metal_t ctx = ctx_enc->ctx;
@@ -1221,10 +1252,10 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
12211252
// otherwise, we add the new ranges to the encoding context and process the node concurrently
12221253
//
12231254
{
1224-
const bool is_concurrent = ggml_metal_encode_concurrency_check(ctx_enc, node);
1255+
const bool is_concurrent = ggml_metal_encoder_concurrency_check(ctx_enc, node);
12251256

12261257
if (!is_concurrent) {
1227-
ggml_metal_encode_concurrency_reset(ctx_enc);
1258+
ggml_metal_encoder_concurrency_reset(ctx_enc);
12281259
}
12291260

12301261
if (ctx->debug_graph > 0) {
@@ -1407,8 +1438,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
14071438
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
14081439

14091440
for (int i = 1; i < n_fuse; ++i) {
1410-
if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
1411-
ggml_metal_encode_concurrency_reset(ctx_enc);
1441+
if (!ggml_metal_encoder_concurrency_check(ctx_enc, nodes[i])) {
1442+
ggml_metal_encoder_concurrency_reset(ctx_enc);
14121443

14131444
break;
14141445
}
@@ -1557,7 +1588,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
15571588

15581589
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
15591590

1560-
ggml_metal_encode_concurrency_reset(ctx_enc);
1591+
ggml_metal_encoder_concurrency_reset(ctx_enc);
15611592
}
15621593

15631594
ggml_metal_kargs_bin args = {
@@ -3025,7 +3056,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
30253056
}
30263057

30273058
// this barrier is always needed because the next kernel has to wait for the id maps to be computed
3028-
ggml_metal_encode_concurrency_reset(ctx_enc);
3059+
ggml_metal_encoder_concurrency_reset(ctx_enc);
30293060

30303061
{
30313062
id<MTLComputePipelineState> pipeline = nil;
@@ -3497,8 +3528,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
34973528
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
34983529

34993530
for (int i = 1; i < n_fuse; ++i) {
3500-
if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
3501-
ggml_metal_encode_concurrency_reset(ctx_enc);
3531+
if (!ggml_metal_encoder_concurrency_check(ctx_enc, nodes[i])) {
3532+
ggml_metal_encoder_concurrency_reset(ctx_enc);
35023533

35033534
break;
35043535
}
@@ -4404,7 +4435,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
44044435
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
44054436

44064437
// sync the 2 kernels
4407-
ggml_metal_encode_concurrency_reset(ctx_enc);
4438+
ggml_metal_encoder_concurrency_reset(ctx_enc);
44084439

44094440
// reduce the results from the workgroups
44104441
{
@@ -4678,8 +4709,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
46784709

46794710
// update the mem ranges in the encoding context
46804711
for (int i = 0; i < n_fuse; ++i) {
4681-
if (!ggml_metal_encode_concurrency_add(ctx_enc, nodes[i])) {
4682-
ggml_metal_encode_concurrency_reset(ctx_enc);
4712+
if (!ggml_metal_encoder_concurrency_add(ctx_enc, nodes[i])) {
4713+
ggml_metal_encoder_concurrency_reset(ctx_enc);
46834714
}
46844715
}
46854716

@@ -4900,9 +4931,9 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
49004931

49014932
const bool should_capture = ctx->capture_next_compute;
49024933

4903-
struct ggml_metal_encode_context ctx_enc = {
4904-
/*.encoder =*/ encoder,
4934+
struct ggml_metal_encoder ctx_enc = {
49054935
/*.ctx =*/ ctx,
4936+
/*.encoder =*/ encoder,
49064937
/*.mem_ranges =*/ mem_ranges,
49074938
};
49084939

@@ -4911,7 +4942,7 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
49114942
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
49124943
}
49134944

4914-
const int res = ggml_metal_encode_node(&ctx_enc, idx, node_end);
4945+
const int res = ggml_metal_encoder_node(&ctx_enc, idx, node_end);
49154946
if (idx + res > node_end) {
49164947
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
49174948
"https://github.com/ggml-org/llama.cpp/pull/14849");

0 commit comments

Comments
 (0)