@@ -2092,47 +2092,138 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         int pt; // type: 0 - src, 1 - dst
     } ranges[MEM_RANGE_MAX];

+    int debug;
 };

-static bool ggml_metal_encode_reset_mem_ranges(struct ggml_metal_encode_context * ctx_enc) {
-    ctx_enc->n_ranges = 0;
+static bool ggml_metal_encode_mem_ranges_reset(struct ggml_metal_encode_context * ctx) {
+    [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+
+    ctx->n_ranges = 0;

     return true;
 }

-static bool ggml_metal_encode_add_mem_range(struct ggml_metal_encode_context * ctx_enc, struct mem_range r) {
-    if (ctx_enc->n_ranges == MEM_RANGE_MAX) {
+static bool ggml_metal_encode_mem_ranges_add(struct ggml_metal_encode_context * ctx, struct mem_range r) {
+    if (ctx->n_ranges == MEM_RANGE_MAX) {
         return false;
     }

-    ctx_enc->ranges[ctx_enc->n_ranges] = r;
+    ctx->ranges[ctx->n_ranges] = r;

-    ctx_enc->n_ranges++;
+    ctx->n_ranges++;

     return true;
 }

-// check and return true:
-//   - if new range overlaps with any existing range of a different type
-//   - if we are close to running out of range cells
-static bool ggml_metal_encode_check_mem_range(struct ggml_metal_encode_context * ctx_enc, struct mem_range r) {
-    if (ctx_enc->n_ranges + 2*GGML_MAX_SRC >= MEM_RANGE_MAX) {
+static bool ggml_metal_encode_mem_ranges_add_src(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!node) {
         return true;
     }

-    for (int i = 0; i < ctx_enc->n_ranges; i++) {
-        if (ctx_enc->ranges[i].pt == r.pt) {
+    size_t offs = 0;
+    id<MTLBuffer> id_node = ggml_metal_get_buffer(node, &offs);
+    GGML_ASSERT(id_node != nil);
+
+    struct mem_range r = {
+        /*.p0 =*/ id_node.gpuAddress + offs,
+        /*.p1 =*/ id_node.gpuAddress + offs + ggml_nbytes(node),
+        /*.pt =*/ 0,
+    };
+
+    if (ctx->debug > 2) {
+        GGML_LOG_DEBUG("%s: add src range [%lld, %lld)\n", __func__, r.p0, r.p1);
+    }
+
+    return ggml_metal_encode_mem_ranges_add(ctx, r);
+}
+
+static bool ggml_metal_encode_mem_ranges_add_dst(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    GGML_ASSERT(node);
+
+    size_t offs = 0;
+    id<MTLBuffer> id_node = ggml_metal_get_buffer(node, &offs);
+    GGML_ASSERT(id_node != nil);
+
+    struct mem_range r = {
+        /*.p0 =*/ id_node.gpuAddress + offs,
+        /*.p1 =*/ id_node.gpuAddress + offs + ggml_nbytes(node),
+        /*.pt =*/ 1,
+    };
+
+    if (ctx->debug > 2) {
+        GGML_LOG_DEBUG("%s: add dst range [%lld, %lld)\n", __func__, r.p0, r.p1);
+    }
+
+    return ggml_metal_encode_mem_ranges_add(ctx, r);
+}
+
+// return true if:
+//   - new src range overlaps with any existing dst range
+//   - new dst range overlaps with any existing range (src or dst)
+static bool ggml_metal_encode_mem_ranges_check(const struct ggml_metal_encode_context * ctx, struct mem_range r) {
+    for (int i = 0; i < ctx->n_ranges; i++) {
+        if (r.pt == 0 && ctx->ranges[i].pt == 0) {
             continue;
         }

-        if (r.p0 < ctx_enc->ranges[i].p1 && r.p1 > ctx_enc->ranges[i].p0) {
+        if (r.p0 < ctx->ranges[i].p1 && r.p1 > ctx->ranges[i].p0) {
             return true;
         }
     }

     return false;
 }

+static bool ggml_metal_encode_mem_ranges_check_src(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!node) {
+        return false;
+    }
+
+    size_t offs = 0;
+    id<MTLBuffer> id_node = ggml_metal_get_buffer(node, &offs);
+    GGML_ASSERT(id_node != nil);
+
+    struct mem_range r = {
+        /*.p0 =*/ id_node.gpuAddress + offs,
+        /*.p1 =*/ id_node.gpuAddress + offs + ggml_nbytes(node),
+        /*.pt =*/ 0,
+    };
+
+    const bool res = ggml_metal_encode_mem_ranges_check(ctx, r);
+
+    if (res) {
+        if (ctx->debug > 2) {
+            GGML_LOG_DEBUG("%s: the src range [%lld, %lld) overlaps with a previous dst range\n", __func__, r.p0, r.p1);
+        }
+    }
+
+    return res;
+}
+
+static bool ggml_metal_encode_mem_ranges_check_dst(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    GGML_ASSERT(node);
+
+    size_t offs = 0;
+    id<MTLBuffer> id_node = ggml_metal_get_buffer(node, &offs);
+    GGML_ASSERT(id_node != nil);
+
+    struct mem_range r = {
+        /*.p0 =*/ id_node.gpuAddress + offs,
+        /*.p1 =*/ id_node.gpuAddress + offs + ggml_nbytes(node),
+        /*.pt =*/ 1,
+    };
+
+    const bool res = ggml_metal_encode_mem_ranges_check(ctx, r);
+
+    if (res) {
+        if (ctx->debug > 2) {
+            GGML_LOG_DEBUG("%s: the dst range [%lld, %lld) overlaps with a previous src range\n", __func__, r.p0, r.p1);
+        }
+    }
+
+    return res;
+}
+
 static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
     ggml_backend_t backend = ctx_enc->backend;

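Aside: the hazard rule implemented by ggml_metal_encode_mem_ranges_check above boils down to a half-open interval overlap test in which src/src pairs are exempt (concurrent reads are safe), while any overlap that involves a dst range is a read-after-write, write-after-read, or write-after-write hazard. The following is a minimal standalone C sketch of that rule, outside the Metal backend; the names range and ranges_conflict are illustrative, not part of this patch.

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    // pt: 0 - src (read), 1 - dst (write); ranges are half-open [p0, p1)
    struct range { uint64_t p0, p1; int pt; };

    // two reads never conflict; any overlap involving a write is a hazard
    static bool ranges_conflict(struct range a, struct range b) {
        if (a.pt == 0 && b.pt == 0) {
            return false;
        }
        return a.p0 < b.p1 && a.p1 > b.p0;
    }

    int main(void) {
        struct range prev_dst = {   0, 128, 1 }; // an earlier node wrote bytes [0, 128)
        struct range new_src  = {  64,  96, 0 }; // a new node wants to read [64, 96)
        struct range new_src2 = { 256, 512, 0 }; // and also to read [256, 512)

        printf("read-after-write hazard: %d\n", ranges_conflict(new_src,  prev_dst)); // 1
        printf("independent read:        %d\n", ranges_conflict(new_src2, prev_dst)); // 0
        return 0;
    }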
@@ -2254,94 +2345,47 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in

     int n_fuse = 1;

+    if (ctx_dev->debug_graph > 0) {
+        GGML_LOG_DEBUG("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+        if (src0) {
+            GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                ggml_is_contiguous(src0), src0->name);
+        }
+        if (src1) {
+            GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                ggml_is_contiguous(src1), src1->name);
+        }
+        if (dst) {
+            GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst->name);
+        }
+    }
+
     // check if the current node can run concurrently with other nodes before it
     // the condition is that:
-    //   - the current node cannot write to any previous src ranges
+    //   - the current node cannot write to any previous src or dst ranges
     //   - the current node cannot read from any previous dst ranges
     //
     // if the condition is not satisfied, we put a memory barrier and clear all ranges
-    // otherwise, we add the new ranges to the encoding context and add the node for concurrent execution
+    // otherwise, we add the new ranges to the encoding context and process the node concurrently
     //
-    bool is_concurrent = ctx_dev->use_concurrency;
+    if (ctx_dev->use_concurrency) {
+        bool is_concurrent = true;

-    if (is_concurrent) {
         // do not read from any previous dst ranges
         for (int i = 0; i < GGML_MAX_SRC; i++) {
-            if (id_src[i] == nil) {
-                continue;
-            }
-
-            struct mem_range r = {
-                /*.p0 =*/ id_src[i].gpuAddress + offs_src[i],
-                /*.p1 =*/ id_src[i].gpuAddress + offs_src[i] + ggml_nbytes(node->src[i]),
-                /*.pt =*/ 0,
-            };
-
-            if (ggml_metal_encode_check_mem_range(ctx_enc, r)) {
-                is_concurrent = false;
-
-                break;
-            }
+            is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_src(ctx_enc, node->src[i]);
         }

-        // do not write to any previous src ranges
-        if (is_concurrent) {
-            struct mem_range r = {
-                /*.p0 =*/ id_dst.gpuAddress + offs_dst,
-                /*.p1 =*/ id_dst.gpuAddress + offs_dst + ggml_nbytes(dst),
-                /*.pt =*/ 1,
-            };
-
-            if (ggml_metal_encode_check_mem_range(ctx_enc, r)) {
-                is_concurrent = false;
-            }
-        }
+        // do not write to any previous ranges
+        is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_dst(ctx_enc, dst);

         if (!is_concurrent) {
-            ggml_metal_encode_reset_mem_ranges(ctx_enc);
-
-            [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+            ggml_metal_encode_mem_ranges_reset(ctx_enc);
         }

-        // add new ranges
-        for (int i = 0; i < GGML_MAX_SRC; i++) {
-            if (id_src[i] == nil) {
-                continue;
-            }
-
-            struct mem_range r = {
-                /*.p0 =*/ id_src[i].gpuAddress + offs_src[i],
-                /*.p1 =*/ id_src[i].gpuAddress + offs_src[i] + ggml_nbytes(node->src[i]),
-                /*.pt =*/ 0,
-            };
-
-            ggml_metal_encode_add_mem_range(ctx_enc, r);
-        }
-
-        {
-            struct mem_range r = {
-                /*.p0 =*/ id_dst.gpuAddress + offs_dst,
-                /*.p1 =*/ id_dst.gpuAddress + offs_dst + ggml_nbytes(dst),
-                /*.pt =*/ 1,
-            };
-
-            ggml_metal_encode_add_mem_range(ctx_enc, r);
-        }
-    }
-
-    if (ctx_dev->debug_graph > 0) {
-        GGML_LOG_INFO("%s: op - %s, concurrent = %d\n", __func__, ggml_op_name(dst->op), is_concurrent);
-        if (src0) {
-            GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
-                ggml_is_contiguous(src0), src0->name);
-        }
-        if (src1) {
-            GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
-                ggml_is_contiguous(src1), src1->name);
-        }
-        if (dst) {
-            GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
-                dst->name);
+        if (ctx_dev->debug_graph > 0) {
+            GGML_LOG_DEBUG("%s: concurrent = %d\n", __func__, is_concurrent);
         }
     }

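Aside: to make the per-node decision above concrete, here is a small standalone C sketch (illustrative names such as tracker and conflicts, not the patch's API): every src of the new node is checked against the previously recorded dst ranges, the new dst is checked against all previous ranges, and a single conflict falls back to a memory barrier plus a reset of the tracked ranges.

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    #define MAX_RANGES 16

    struct range { uint64_t p0, p1; int pt; }; // pt: 0 - src, 1 - dst

    struct tracker { struct range r[MAX_RANGES]; int n; };

    static bool overlap(struct range a, struct range b) {
        return a.p0 < b.p1 && a.p1 > b.p0;
    }

    // a src range conflicts only with previous dst ranges; a dst range conflicts with everything
    static bool conflicts(const struct tracker * t, struct range x) {
        for (int i = 0; i < t->n; i++) {
            if (x.pt == 0 && t->r[i].pt == 0) continue;
            if (overlap(x, t->r[i])) return true;
        }
        return false;
    }

    int main(void) {
        struct tracker t = { .n = 0 };

        // node A: reads [0, 64), writes [64, 128) -> no previous ranges, runs concurrently
        t.r[t.n++] = (struct range){  0,  64, 0 };
        t.r[t.n++] = (struct range){ 64, 128, 1 };

        // node B: reads [64, 80) -> overlaps A's dst range, so a barrier is needed
        struct range b_src = { 64, 80, 0 };
        if (conflicts(&t, b_src)) {
            printf("barrier + reset before node B\n");
            t.n = 0; // after the (conceptual) memory barrier, start tracking from scratch
        }
        return 0;
    }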
@@ -2544,6 +2588,22 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
                 }

+                if (ctx_dev->use_concurrency && n_fuse > 1) {
+                    bool is_concurrent = true;
+
+                    // make sure that none of the fused nodes reads from a previous dst range
+                    for (int i = 1; i < n_fuse; ++i) {
+                        is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_src(ctx_enc, nodes[i]->src[1]);
+                    }
+
+                    // do not write to any previous range
+                    is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_dst(ctx_enc, nodes[n_fuse - 1]);
+
+                    if (!is_concurrent) {
+                        ggml_metal_encode_mem_ranges_reset(ctx_enc);
+                    }
+                }
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
@@ -2688,7 +2748,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];

                 if (ctx_dev->use_concurrency) {
-                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
                 }
             }

@@ -4215,7 +4275,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 }

                 if (ctx_dev->use_concurrency) {
-                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
                 }

                 {
@@ -4688,6 +4748,22 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
                 }

+                if (ctx_dev->use_concurrency) {
+                    bool is_concurrent = true;
+
+                    // make sure that none of the fused nodes reads from a previous dst range
+                    for (int i = 1; i < n_fuse; ++i) {
+                        is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_src(ctx_enc, nodes[i]->src[1]);
+                    }
+
+                    // do not write to any previous range
+                    is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_dst(ctx_enc, nodes[n_fuse - 1]);
+
+                    if (!is_concurrent) {
+                        ggml_metal_encode_mem_ranges_reset(ctx_enc);
+                    }
+                }
+
                 id<MTLComputePipelineState> pipeline;

                 switch (n_fuse) {
@@ -5608,7 +5684,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];

                 if (ctx_dev->use_concurrency) {
-                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
                 }

                 // reduce the results from the workgroups
@@ -5875,28 +5951,28 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in

     if (ctx_dev->debug_graph > 0) {
         if (n_fuse > 1) {
-            GGML_LOG_INFO("%s: fuse: %d ops\n", __func__, n_fuse);
+            GGML_LOG_DEBUG("%s: fuse: %d ops\n", __func__, n_fuse);
         }
     }

-    // after fusing, we have to add the new destination range to the encoding context
-    if (ctx_dev->use_concurrency && n_fuse > 1) {
-        struct ggml_tensor * dstf = nodes[n_fuse - 1];
-
-        size_t offs_dstf = 0;
+    // update the mem ranges in the encoding context
+    if (ctx_dev->use_concurrency) {
+        bool ok = true;

-        id<MTLBuffer> id_dstf = dstf ? ggml_metal_get_buffer(dstf, &offs_dstf) : nil;
+        // add new src ranges
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            ok = ok && ggml_metal_encode_mem_ranges_add_src(ctx_enc, node->src[i]);
+        }

-        struct mem_range r = {
-            /*.p0 =*/ id_dstf.gpuAddress + offs_dstf,
-            /*.p1 =*/ id_dstf.gpuAddress + offs_dstf + ggml_nbytes(dstf),
-            /*.pt =*/ 1,
-        };
+        // add the destination range
+        ok = ok && ggml_metal_encode_mem_ranges_add_dst(ctx_enc, nodes[n_fuse - 1]);

-        if (!ggml_metal_encode_add_mem_range(ctx_enc, r)) {
-            ggml_metal_encode_reset_mem_ranges(ctx_enc);
+        if (!ok) {
+            if (ctx_dev->debug_graph > 2) {
+                GGML_LOG_DEBUG("%s: the range cache is full -> reset and put a barrier\n", __func__);
+            }

-            [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+            ggml_metal_encode_mem_ranges_reset(ctx_enc);
         }
     }

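Aside: the bookkeeping above has one more fallback worth noting: the range cache is a fixed-size array (MEM_RANGE_MAX entries), so if adding the new src/dst ranges fails, the encoder simply resets, which in this patch also encodes the memory barrier. A minimal standalone C sketch of that pattern, with illustrative names (tracker_add, tracker_reset) rather than the patch's API:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    #define MAX_RANGES 4 // deliberately tiny to trigger the overflow path

    struct range { uint64_t p0, p1; int pt; };
    struct tracker { struct range r[MAX_RANGES]; int n; };

    static bool tracker_add(struct tracker * t, struct range x) {
        if (t->n == MAX_RANGES) {
            return false; // cache full - the caller must fall back to a barrier
        }
        t->r[t->n++] = x;
        return true;
    }

    static void tracker_reset(struct tracker * t) {
        // in the backend this is where the Metal memory barrier is encoded
        t->n = 0;
    }

    int main(void) {
        struct tracker t = { .n = 0 };

        bool ok = true;
        for (int i = 0; i < 5; i++) { // more ranges than the cache can hold
            ok = ok && tracker_add(&t, (struct range){ 128u * i, 128u * (i + 1), 0 });
        }

        if (!ok) {
            printf("range cache full -> reset and put a barrier\n");
            tracker_reset(&t);
        }
        return 0;
    }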
@@ -6792,6 +6868,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
             /*.mem_pool =*/ mem_pool,
             /*.n_ranges =*/ 0,
             /*.ranges   =*/ { 0 },
+            /*.debug    =*/ ctx_dev->debug_graph,
         };

         for (int idx = node_start; idx < node_end;) {