@@ -2063,12 +2063,71 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
20632063 }
20642064}
20652065
2066- static int ggml_metal_encode_node (
2067- ggml_backend_t backend,
2068- int idx,
2069- int idx_end,
2070- id <MTLComputeCommandEncoder > encoder,
2071- struct ggml_metal_mem_pool * mem_pool) {
// maximum number of src/dst memory ranges tracked per encoder
#define MEM_RANGE_MAX 128

// per-encoder state passed to ggml_metal_encode_node
//
// tracks the src (read) and dst (write) GPU address ranges of the nodes already
// encoded on this concurrent encoder, so that a new node can be checked for
// hazards against earlier nodes before deciding whether a memory barrier is needed
struct ggml_metal_encode_context {
    ggml_backend_t backend;

    id<MTLComputeCommandEncoder> encoder;

    struct ggml_metal_mem_pool * mem_pool;

    // number of valid entries in ranges[]
    int n_ranges;

    struct mem_range {
        uint64_t p0; // begin (GPU address of the first byte)
        uint64_t p1; // end   (one past the last byte - half-open interval)
        int      pt; // type: 0 - src (read), 1 - dst (write)
    } ranges[MEM_RANGE_MAX];
};
2084+
// discard all tracked memory ranges (used after emitting a memory barrier,
// which resolves every outstanding hazard); always succeeds
static bool ggml_metal_encode_reset_mem_ranges(struct ggml_metal_encode_context * ctx_enc) {
    ctx_enc->n_ranges = 0;

    return true;
}
2090+
// record a new memory range in the encoding context
// returns false when the range table is already full, in which case the caller
// is expected to flush (barrier + reset) instead
static bool ggml_metal_encode_add_mem_range(struct ggml_metal_encode_context * ctx_enc, struct mem_range r) {
    const bool has_room = ctx_enc->n_ranges < MEM_RANGE_MAX;

    if (has_room) {
        ctx_enc->ranges[ctx_enc->n_ranges++] = r;
    }

    return has_room;
}
2102+
// check and return true:
// - if the new range overlaps with any existing range of a different type
//   (a read-after-write or write-after-read hazard), or
// - if we are close to running out of range cells
static bool ggml_metal_encode_check_mem_range(struct ggml_metal_encode_context * ctx_enc, struct mem_range r) {
    // keep enough headroom so the ranges of an upcoming node can still be added
    const bool near_full = ctx_enc->n_ranges + 2*GGML_MAX_SRC >= MEM_RANGE_MAX;
    if (near_full) {
        return true;
    }

    for (int i = 0; i < ctx_enc->n_ranges; i++) {
        const struct mem_range * cur = &ctx_enc->ranges[i];

        // same-type ranges (read-read or write-write) never conflict
        if (cur->pt == r.pt) {
            continue;
        }

        // half-open interval intersection: [r.p0, r.p1) vs [cur->p0, cur->p1)
        if (r.p0 < cur->p1 && r.p1 > cur->p0) {
            return true;
        }
    }

    return false;
}
2123+
2124+ static int ggml_metal_encode_node (struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
2125+ ggml_backend_t backend = ctx_enc->backend ;
2126+
2127+ id <MTLComputeCommandEncoder > encoder = ctx_enc->encoder ;
2128+
2129+ struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool ;
2130+
20722131 struct ggml_backend_metal_context * ctx = backend->context ;
20732132 struct ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
20742133
@@ -2151,22 +2210,113 @@ static int ggml_metal_encode_node(
21512210 const uint64_t nb2 = dst ? dst->nb [2 ] : 0 ;
21522211 const uint64_t nb3 = dst ? dst->nb [3 ] : 0 ;
21532212
2213+ size_t offs_src[GGML_MAX_SRC];
2214+
2215+ id <MTLBuffer > id_src[GGML_MAX_SRC];
2216+
2217+ enum ggml_type srct[GGML_MAX_SRC];
2218+
2219+ for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2220+ offs_src[i] = 0 ;
2221+ id_src[i] = node->src [i] ? ggml_metal_get_buffer (node->src [i], &offs_src[i]) : nil ;
2222+ srct[i] = node->src [i] ? node->src [i]->type : GGML_TYPE_COUNT;
2223+ }
2224+
2225+ size_t offs_dst = 0 ;
2226+
2227+ id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
2228+
2229+ int n_fuse = 1 ;
2230+
2231+ // check if the current node can run concurrently with other nodes before it
2232+ // the condition is that:
2233+ // - the current node cannot write to any previous src ranges
2234+ // - the current node cannot read from any previous dst ranges
2235+ //
2236+ // if the condition is not satisfied, we put a memory barrier and clear all ranges
2237+ // otherwise, we add the new ranges to the encoding context and add the node for concurrent execution
2238+ //
2239+ {
2240+ bool is_concurrent = true ;
2241+
2242+ // do not read from any previous dst ranges
2243+ for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2244+ if (id_src[i] == nil ) {
2245+ continue ;
2246+ }
2247+
2248+ struct mem_range r = {
2249+ /* .p0 =*/ id_src[i].gpuAddress + offs_src[i],
2250+ /* .p1 =*/ id_src[i].gpuAddress + offs_src[i] + ggml_nbytes (node->src [i]),
2251+ /* .pt =*/ 0 ,
2252+ };
2253+
2254+ if (ggml_metal_encode_check_mem_range (ctx_enc, r)) {
2255+ is_concurrent = false ;
2256+
2257+ break ;
2258+ }
2259+ }
2260+
2261+ // do not write to any previous src ranges
2262+ if (is_concurrent) {
2263+ struct mem_range r = {
2264+ /* .p0 =*/ id_dst.gpuAddress + offs_dst,
2265+ /* .p1 =*/ id_dst.gpuAddress + offs_dst + ggml_nbytes (dst),
2266+ /* .pt =*/ 1 ,
2267+ };
2268+
2269+ if (ggml_metal_encode_check_mem_range (ctx_enc, r)) {
2270+ is_concurrent = false ;
2271+ }
2272+ }
2273+
2274+ if (!is_concurrent) {
2275+ ggml_metal_encode_reset_mem_ranges (ctx_enc);
2276+
2277+ [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
2278+ }
2279+
2280+ // add new ranges
2281+ for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2282+ if (id_src[i] == nil ) {
2283+ continue ;
2284+ }
2285+
2286+ struct mem_range r = {
2287+ /* .p0 =*/ id_src[i].gpuAddress + offs_src[i],
2288+ /* .p1 =*/ id_src[i].gpuAddress + offs_src[i] + ggml_nbytes (node->src [i]),
2289+ /* .pt =*/ 0 ,
2290+ };
2291+
2292+ ggml_metal_encode_add_mem_range (ctx_enc, r);
2293+ }
2294+
2295+ {
2296+ struct mem_range r = {
2297+ /* .p0 =*/ id_dst.gpuAddress + offs_dst,
2298+ /* .p1 =*/ id_dst.gpuAddress + offs_dst + ggml_nbytes (dst),
2299+ /* .pt =*/ 1 ,
2300+ };
2301+
2302+ ggml_metal_encode_add_mem_range (ctx_enc, r);
2303+ }
2304+ }
2305+
2306+ // TODO: tmp shorthands - remove
2307+ size_t offs_src0 = offs_src[0 ];
2308+ size_t offs_src1 = offs_src[1 ];
2309+ size_t offs_src2 = offs_src[2 ];
2310+
2311+ id <MTLBuffer > id_src0 = id_src[0 ];
2312+ id <MTLBuffer > id_src1 = id_src[1 ];
2313+ id <MTLBuffer > id_src2 = id_src[2 ];
2314+
21542315 const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
21552316 const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
21562317 const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
21572318 const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
21582319
2159- size_t offs_src0 = 0 ;
2160- size_t offs_src1 = 0 ;
2161- size_t offs_src2 = 0 ;
2162- size_t offs_dst = 0 ;
2163-
2164- id <MTLBuffer > id_src0 = src0 ? ggml_metal_get_buffer (src0, &offs_src0) : nil ;
2165- id <MTLBuffer > id_src1 = src1 ? ggml_metal_get_buffer (src1, &offs_src1) : nil ;
2166- id <MTLBuffer > id_src2 = src2 ? ggml_metal_get_buffer (src2, &offs_src2) : nil ;
2167- id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
2168-
2169- int n_fuse = 1 ;
21702320
21712321#if 0
21722322 GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@@ -2525,6 +2675,8 @@ static int ggml_metal_encode_node(
25252675 const int nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne00);
25262676
25272677 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
2678+
2679+ [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
25282680 }
25292681
25302682 const id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ;
@@ -4049,6 +4201,8 @@ static int ggml_metal_encode_node(
40494201 [encoder dispatchThreadgroups: MTLSizeMake (1 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (ne02, 1 , 1 )];
40504202 }
40514203
4204+ [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
4205+
40524206 {
40534207 id <MTLComputePipelineState > pipeline = nil ;
40544208
@@ -4660,7 +4814,6 @@ static int ggml_metal_encode_node(
46604814 } break ;
46614815 case GGML_OP_ROPE:
46624816 {
4663-
46644817 // make sure we have one or more position id(ne10) per token(ne02)
46654818 GGML_ASSERT (ne10 % ne02 == 0 );
46664819 GGML_ASSERT (ne10 >= ne02);
@@ -5439,6 +5592,8 @@ static int ggml_metal_encode_node(
54395592 [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
54405593 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
54415594
5595+ [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
5596+
54425597 // reduce the results from the workgroups
54435598 {
54445599 ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
@@ -5669,7 +5824,7 @@ static int ggml_metal_encode_node(
56695824
56705825 [encoder dispatchThreadgroups: MTLSizeMake (n_tg, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (n_threads, 1 , 1 )];
56715826 } break ;
5672- case GGML_OP_ARGMAX:
5827+ case GGML_OP_ARGMAX:
56735828 {
56745829 GGML_ASSERT (src0->type == GGML_TYPE_F32);
56755830 GGML_ASSERT (ggml_is_contiguous_1 (src0));
@@ -5701,6 +5856,27 @@ static int ggml_metal_encode_node(
57015856 }
57025857 }
57035858
5859+ // after fusing, we have to add the new destination range to the encoding context
5860+ if (n_fuse > 1 ) {
5861+ struct ggml_tensor * dstf = nodes[n_fuse - 1 ];
5862+
5863+ size_t offs_dstf = 0 ;
5864+
5865+ id <MTLBuffer > id_dstf = dstf ? ggml_metal_get_buffer (dstf, &offs_dstf) : nil ;
5866+
5867+ struct mem_range r = {
5868+ /* .p0 =*/ id_dstf.gpuAddress + offs_dstf,
5869+ /* .p1 =*/ id_dstf.gpuAddress + offs_dstf + ggml_nbytes (dstf),
5870+ /* .pt =*/ 1 ,
5871+ };
5872+
5873+ if (!ggml_metal_encode_add_mem_range (ctx_enc, r)) {
5874+ ggml_metal_encode_reset_mem_ranges (ctx_enc);
5875+
5876+ [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
5877+ }
5878+ }
5879+
57045880 return n_fuse;
57055881}
57065882
@@ -6567,7 +6743,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
65676743
65686744 ggml_metal_mem_pool_reset (mem_pool);
65696745
6570- id <MTLComputeCommandEncoder > encoder = [cmd_buf computeCommandEncoder ];
6746+ id <MTLComputeCommandEncoder > encoder = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent ];
65716747
65726748 int node_start = 0 ;
65736749 int node_end = n_nodes_0;
@@ -6579,12 +6755,20 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
65796755
65806756 const bool should_capture = ctx->capture_next_compute ;
65816757
6758+ struct ggml_metal_encode_context ctx_enc = {
6759+ /* .backend =*/ backend,
6760+ /* .encoder =*/ encoder,
6761+ /* .mem_pool =*/ mem_pool,
6762+ /* .n_ranges =*/ 0 ,
6763+ /* .ranges =*/ { 0 },
6764+ };
6765+
65826766 for (int idx = node_start; idx < node_end;) {
65836767 if (should_capture) {
65846768 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
65856769 }
65866770
6587- const int res = ggml_metal_encode_node (backend , idx, node_end, encoder, mem_pool );
6771+ const int res = ggml_metal_encode_node (&ctx_enc , idx, node_end);
65886772 if (idx + res > node_end) {
65896773 GGML_ABORT (" fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s " ,
65906774 " https://github.com/ggml-org/llama.cpp/pull/14849" );
0 commit comments