33#import " ggml-impl.h"
44#import " ggml-backend-impl.h"
55#import " ggml-metal-impl.h"
6+ #import " ggml-metal-common.h"
67
78#import < Foundation/Foundation.h>
89
@@ -2075,42 +2076,20 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
20752076 }
20762077}
20772078
2078- #define MEM_RANGE_MAX 128
2079-
// State shared by the helpers that encode graph nodes into a Metal compute encoder.
//
// NOTE: the member order is significant - the context is constructed with a positional
//       initializer (backend, encoder, mem_pool, mem_ranges), so do not reorder the fields.
struct ggml_metal_encode_context {
    // backend on whose behalf the commands are encoded
    ggml_backend_t backend;

    // compute command encoder that receives the dispatches and memory barriers
    id<MTLComputeCommandEncoder> encoder;

    // memory pool handed to the encoding routines
    struct ggml_metal_mem_pool * mem_pool;

    // tracker for the src/dst memory ranges of the nodes encoded so far
    // it is initialized to NULL and only allocated when the device context enables concurrency
    struct ggml_mem_ranges * mem_ranges;
};
20972088
// Encode a buffer-scope memory barrier and drop all memory ranges tracked so far.
//
// The barrier is always encoded. The tracker itself is only reset when one has been
// allocated: ctx->mem_ranges is initialized to NULL and is only created when the device
// context enables concurrency, so a NULL tracker must not be handed to the reset helper.
//
// Always returns true.
static bool ggml_metal_encode_mem_ranges_reset(struct ggml_metal_encode_context * ctx) {
    [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];

    // nothing to clear when range tracking is not in use
    if (ctx->mem_ranges != NULL) {
        ggml_mem_ranges_reset(ctx->mem_ranges);
    }

    return true;
}
@@ -2120,92 +2099,27 @@ static bool ggml_metal_encode_mem_ranges_add_src(struct ggml_metal_encode_contex
21202099 return true ;
21212100 }
21222101
2123- struct mem_range r = {
2124- /* .p0 =*/ (uint64_t ) node->data ,
2125- /* .p1 =*/ (uint64_t ) node->data + ggml_nbytes (node),
2126- /* .pt =*/ 0 ,
2127- };
2128-
2129- if (ctx->debug > 2 ) {
2130- GGML_LOG_DEBUG (" %s : add src range [%lld , %lld )\n " , __func__, r.p0 , r.p1 );
2131- }
2132-
2133- return ggml_metal_encode_mem_ranges_add (ctx, r);
2102+ return ggml_mem_ranges_add_src (ctx->mem_ranges , node);
21342103}
21352104
// Register `node` as a dst tensor with the encoder's memory-range tracker.
//
// `node` must not be NULL (asserted).
// Returns whatever ggml_mem_ranges_add_dst() reports.
//
// NOTE(review): ctx->mem_ranges is NULL when concurrency is disabled - confirm that callers
//               only reach this function when a tracker has been allocated.
static bool ggml_metal_encode_mem_ranges_add_dst(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
    GGML_ASSERT(node);

    const bool res = ggml_mem_ranges_add_dst(ctx->mem_ranges, node);

    return res;
}
21682110
// Query the encoder's memory-range tracker about `node` used as a src tensor.
//
// A NULL `node` (absent source) yields false without consulting the tracker.
// Otherwise the result of ggml_mem_ranges_check_src() is returned as-is.
//
// NOTE(review): presumably true means the range overlaps a previously added dst range -
//               confirm against the declaration in ggml-metal-common.h.
static bool ggml_metal_encode_mem_ranges_check_src(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
    // short-circuit keeps the tracker from ever seeing a NULL tensor
    return node != NULL && ggml_mem_ranges_check_src(ctx->mem_ranges, node);
}
21902118
// Query the encoder's memory-range tracker about `node` used as a dst tensor.
//
// `node` must not be NULL (asserted).
// Returns the result of ggml_mem_ranges_check_dst() as-is.
//
// NOTE(review): presumably true means the range overlaps a previously added src or dst
//               range - confirm against the declaration in ggml-metal-common.h.
static bool ggml_metal_encode_mem_ranges_check_dst(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
    GGML_ASSERT(node);

    const bool res = ggml_mem_ranges_check_dst(ctx->mem_ranges, node);

    return res;
}
22102124
22112125static int ggml_metal_encode_node (struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
@@ -6847,14 +6761,16 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
68476761 const bool should_capture = ctx->capture_next_compute ;
68486762
68496763 struct ggml_metal_encode_context ctx_enc = {
6850- /* .backend =*/ backend,
6851- /* .encoder =*/ encoder,
6852- /* .mem_pool =*/ mem_pool,
6853- /* .n_ranges =*/ 0 ,
6854- /* .ranges =*/ { 0 },
6855- /* .debug =*/ ctx_dev->debug_graph ,
6764+ /* .backend =*/ backend,
6765+ /* .encoder =*/ encoder,
6766+ /* .mem_pool =*/ mem_pool,
6767+ /* .mem_ranges =*/ NULL ,
68566768 };
68576769
6770+ if (ctx_dev->use_concurrency ) {
6771+ ctx_enc.mem_ranges = ggml_mem_ranges_init (ctx_dev->debug_graph );
6772+ }
6773+
68586774 for (int idx = node_start; idx < node_end;) {
68596775 if (should_capture) {
68606776 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
@@ -6879,6 +6795,8 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
68796795
68806796 [encoder endEncoding ];
68816797
6798+ ggml_mem_ranges_free (ctx_enc.mem_ranges );
6799+
68826800 if (cb_idx < 2 || ctx->abort_callback == NULL ) {
68836801 [cmd_buf commit ];
68846802 }
0 commit comments