@@ -379,7 +379,6 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {
379379 id <MTLLibrary > library;
380380 id <MTLCommandQueue > queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND]
381381
382- // struct ggml_metal_device_props props_dev;
383382 ggml_metal_device_t ctx_dev;
384383
385384 dispatch_queue_t d_queue;
@@ -1062,15 +1061,47 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te
10621061 }
10631062}
10641063
1065- struct ggml_metal_encode_context {
1066- id <MTLComputeCommandEncoder > encoder;
1067-
1064+ struct ggml_metal_encoder {
10681065 ggml_metal_t ctx;
10691066
1067+ id <MTLComputeCommandEncoder > encoder;
1068+
10701069 ggml_mem_ranges_t mem_ranges;
10711070};
10721071
1073- static bool ggml_metal_encode_concurrency_reset (struct ggml_metal_encode_context * ctx) {
1072+ ggml_metal_encoder_t ggml_metal_encoder_init (ggml_metal_t ctx, int cb_idx) {
1073+ ggml_metal_encoder_t res = calloc (1 , sizeof (struct ggml_metal_encoder));
1074+ res->ctx = ctx;
1075+
1076+ id <MTLCommandBuffer > cmd_buf = [ctx->que
1077+
1078+ if (ctx->use_concurrency) {
1079+ res->encoder = [ctx->queue computeCommandEncoder ];
1080+ res->mem_ranges = ggml_mem_ranges_init (ctx->debug_graph );
1081+ } else {
1082+ res->mem_ranges = nil ;
1083+ }
1084+
1085+ }
1086+
1087+
1088+ void ggml_metal_encoder_free (ggml_metal_encoder_t ctx);
1089+
1090+ void ggml_metal_encoder_begin (ggml_metal_encoder_t ctx, int idx) {
1091+ if (ctx->ctx ->capture_next_compute ) {
1092+ [ctx->encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
1093+ }
1094+ }
1095+
1096+ void ggml_metal_encoder_end (ggml_metal_encoder_t ctx, int idx) {
1097+ if (ctx->ctx ->capture_next_compute ) {
1098+ [ctx->encoder popDebugGroup ];
1099+ }
1100+
1101+ GGML_UNUSED (idx);
1102+ }
1103+
1104+ bool ggml_metal_encoder_concurrency_reset (struct ggml_metal_encoder * ctx) {
10741105 if (!ctx->mem_ranges ) {
10751106 return true ;
10761107 }
@@ -1082,23 +1113,23 @@ static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context
10821113 return true ;
10831114}
10841115
1085- static bool ggml_metal_encode_concurrency_check (struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
1116+ bool ggml_metal_encoder_concurrency_check (struct ggml_metal_encoder * ctx, const struct ggml_tensor * node) {
10861117 if (!ctx->mem_ranges ) {
10871118 return false ;
10881119 }
10891120
10901121 return ggml_mem_ranges_check (ctx->mem_ranges , node);
10911122}
10921123
1093- static bool ggml_metal_encode_concurrency_add (struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
1124+ bool ggml_metal_encoder_concurrency_add (struct ggml_metal_encoder * ctx, const struct ggml_tensor * node) {
10941125 if (!ctx->mem_ranges ) {
10951126 return true ;
10961127 }
10971128
10981129 return ggml_mem_ranges_add (ctx->mem_ranges , node);
10991130}
11001131
1101- static int ggml_metal_encode_node (struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
1132+ static int ggml_metal_encoder_node (struct ggml_metal_encoder * ctx_enc, int idx, int idx_end) {
11021133 id <MTLComputeCommandEncoder > encoder = ctx_enc->encoder ;
11031134
11041135 ggml_metal_t ctx = ctx_enc->ctx ;
@@ -1221,10 +1252,10 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
12211252 // otherwise, we add the new ranges to the encoding context and process the node concurrently
12221253 //
12231254 {
1224- const bool is_concurrent = ggml_metal_encode_concurrency_check (ctx_enc, node);
1255+ const bool is_concurrent = ggml_metal_encoder_concurrency_check (ctx_enc, node);
12251256
12261257 if (!is_concurrent) {
1227- ggml_metal_encode_concurrency_reset (ctx_enc);
1258+ ggml_metal_encoder_concurrency_reset (ctx_enc);
12281259 }
12291260
12301261 if (ctx->debug_graph > 0 ) {
@@ -1407,8 +1438,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
14071438 id_dst = ggml_metal_get_buffer (nodes[n_fuse - 1 ], &offs_dst);
14081439
14091440 for (int i = 1 ; i < n_fuse; ++i) {
1410- if (!ggml_metal_encode_concurrency_check (ctx_enc, nodes[i])) {
1411- ggml_metal_encode_concurrency_reset (ctx_enc);
1441+ if (!ggml_metal_encoder_concurrency_check (ctx_enc, nodes[i])) {
1442+ ggml_metal_encoder_concurrency_reset (ctx_enc);
14121443
14131444 break ;
14141445 }
@@ -1557,7 +1588,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
15571588
15581589 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
15591590
1560- ggml_metal_encode_concurrency_reset (ctx_enc);
1591+ ggml_metal_encoder_concurrency_reset (ctx_enc);
15611592 }
15621593
15631594 ggml_metal_kargs_bin args = {
@@ -3025,7 +3056,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
30253056 }
30263057
30273058 // this barrier is always needed because the next kernel has to wait for the id maps to be computed
3028- ggml_metal_encode_concurrency_reset (ctx_enc);
3059+ ggml_metal_encoder_concurrency_reset (ctx_enc);
30293060
30303061 {
30313062 id <MTLComputePipelineState > pipeline = nil ;
@@ -3497,8 +3528,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
34973528 id_dst = ggml_metal_get_buffer (nodes[n_fuse - 1 ], &offs_dst);
34983529
34993530 for (int i = 1 ; i < n_fuse; ++i) {
3500- if (!ggml_metal_encode_concurrency_check (ctx_enc, nodes[i])) {
3501- ggml_metal_encode_concurrency_reset (ctx_enc);
3531+ if (!ggml_metal_encoder_concurrency_check (ctx_enc, nodes[i])) {
3532+ ggml_metal_encoder_concurrency_reset (ctx_enc);
35023533
35033534 break ;
35043535 }
@@ -4404,7 +4435,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
44044435 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
44054436
44064437 // sync the 2 kernels
4407- ggml_metal_encode_concurrency_reset (ctx_enc);
4438+ ggml_metal_encoder_concurrency_reset (ctx_enc);
44084439
44094440 // reduce the results from the workgroups
44104441 {
@@ -4678,8 +4709,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
46784709
46794710 // update the mem ranges in the encoding context
46804711 for (int i = 0 ; i < n_fuse; ++i) {
4681- if (!ggml_metal_encode_concurrency_add (ctx_enc, nodes[i])) {
4682- ggml_metal_encode_concurrency_reset (ctx_enc);
4712+ if (!ggml_metal_encoder_concurrency_add (ctx_enc, nodes[i])) {
4713+ ggml_metal_encoder_concurrency_reset (ctx_enc);
46834714 }
46844715 }
46854716
@@ -4900,9 +4931,9 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
49004931
49014932 const bool should_capture = ctx->capture_next_compute ;
49024933
4903- struct ggml_metal_encode_context ctx_enc = {
4904- /* .encoder =*/ encoder,
4934+ struct ggml_metal_encoder ctx_enc = {
49054935 /* .ctx =*/ ctx,
4936+ /* .encoder =*/ encoder,
49064937 /* .mem_ranges =*/ mem_ranges,
49074938 };
49084939
@@ -4911,7 +4942,7 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
49114942 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
49124943 }
49134944
4914- const int res = ggml_metal_encode_node (&ctx_enc, idx, node_end);
4945+ const int res = ggml_metal_encoder_node (&ctx_enc, idx, node_end);
49154946 if (idx + res > node_end) {
49164947 GGML_ABORT (" fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s " ,
49174948 " https://github.com/ggml-org/llama.cpp/pull/14849" );
0 commit comments