6161    bool  has_bfloat;
6262    bool  use_bfloat;
6363    bool  use_fusion;
64+     bool  use_concurrency;
6465    bool  use_shared_buffers;
6566
67+     int  debug_graph;
6668    int  debug_fusion;
6769
6870    //  how many times a given op was fused
8385    /* .has_bfloat              =*/   false ,
8486    /* .use_bfloat              =*/   false ,
8587    /* .use_fusion              =*/   true ,
88+     /* .use_concurrency         =*/   true ,
8689    /* .use_shared_buffers      =*/   true ,
90+     /* .debug_graph             =*/   0 ,
8791    /* .debug_fusion            =*/   0 ,
8892    /* .fuse_cnt                =*/   { 0  },
8993    /* .max_size                =*/   0 ,
124128#else 
125129            ctx->use_bfloat  = false ;
126130#endif 
127-             ctx->use_fusion  = getenv (" GGML_METAL_FUSION_DISABLE"  ) == nil ;
131+ 
132+             ctx->use_fusion       = getenv (" GGML_METAL_FUSION_DISABLE"  ) == nil ;
133+             ctx->use_concurrency  = getenv (" GGML_METAL_CONCURRENCY_DISABLE"  ) == nil ;
134+ 
135+             {
136+                 const  char  * val = getenv (" GGML_METAL_GRAPH_DEBUG"  );
137+                 ctx->debug_graph  = val ? atoi (val) : 0 ;
138+             }
128139
129140            {
130141                const  char  * val = getenv (" GGML_METAL_FUSION_DEBUG"  );
@@ -1091,6 +1102,7 @@ @implementation GGMLMetalClass
10911102    GGML_LOG_INFO (" %s : has bfloat            = %s \n "  , __func__, ctx_dev->has_bfloat                   ? " true"   : " false"  );
10921103    GGML_LOG_INFO (" %s : use bfloat            = %s \n "  , __func__, ctx_dev->use_bfloat                   ? " true"   : " false"  );
10931104    GGML_LOG_INFO (" %s : use fusion            = %s \n "  , __func__, ctx_dev->use_fusion                   ? " true"   : " false"  );
1105+     GGML_LOG_INFO (" %s : use concurrency       = %s \n "  , __func__, ctx_dev->use_concurrency              ? " true"   : " false"  );
10941106    GGML_LOG_INFO (" %s : use shared buffers    = %s \n "  , __func__, ctx_dev->use_shared_buffers           ? " true"   : " false"  );
10951107    GGML_LOG_INFO (" %s : hasUnifiedMemory      = %s \n "  , __func__, ctx_dev->mtl_device .hasUnifiedMemory  ? " true"   : " false"  );
10961108
@@ -2222,6 +2234,20 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
22222234        srct[i]   = node->src [i] ? node->src [i]->type  : GGML_TYPE_COUNT;
22232235    }
22242236
2237+     //  TODO: tmp shorthands - remove
2238+     size_t  offs_src0 = offs_src[0 ];
2239+     size_t  offs_src1 = offs_src[1 ];
2240+     size_t  offs_src2 = offs_src[2 ];
2241+ 
2242+     id <MTLBuffer > id_src0 = id_src[0 ];
2243+     id <MTLBuffer > id_src1 = id_src[1 ];
2244+     id <MTLBuffer > id_src2 = id_src[2 ];
2245+ 
2246+     const  enum  ggml_type src0t = src0 ? src0->type  : GGML_TYPE_COUNT;
2247+     const  enum  ggml_type src1t = src1 ? src1->type  : GGML_TYPE_COUNT;
2248+     const  enum  ggml_type src2t = src2 ? src2->type  : GGML_TYPE_COUNT;
2249+     const  enum  ggml_type dstt  = dst  ? dst->type   : GGML_TYPE_COUNT;
2250+ 
22252251    size_t  offs_dst = 0 ;
22262252
22272253    id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
@@ -2236,9 +2262,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
22362262    //  if the condition is not satisfied, we put a memory barrier and clear all ranges
22372263    //  otherwise, we add the new ranges to the encoding context and add the node for concurrent execution
22382264    // 
2239-     {
2240-         bool  is_concurrent = true ;
2265+     bool  is_concurrent = ctx_dev->use_concurrency ;
22412266
2267+     if  (is_concurrent) {
22422268        //  do not read from any previous dst ranges
22432269        for  (int  i = 0 ; i < GGML_MAX_SRC; i++) {
22442270            if  (id_src[i] == nil ) {
@@ -2303,36 +2329,21 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
23032329        }
23042330    }
23052331
2306-     //  TODO: tmp shorthands - remove
2307-     size_t  offs_src0 = offs_src[0 ];
2308-     size_t  offs_src1 = offs_src[1 ];
2309-     size_t  offs_src2 = offs_src[2 ];
2310- 
2311-     id <MTLBuffer > id_src0 = id_src[0 ];
2312-     id <MTLBuffer > id_src1 = id_src[1 ];
2313-     id <MTLBuffer > id_src2 = id_src[2 ];
2314- 
2315-     const  enum  ggml_type src0t = src0 ? src0->type  : GGML_TYPE_COUNT;
2316-     const  enum  ggml_type src1t = src1 ? src1->type  : GGML_TYPE_COUNT;
2317-     const  enum  ggml_type src2t = src2 ? src2->type  : GGML_TYPE_COUNT;
2318-     const  enum  ggml_type dstt  = dst  ? dst->type   : GGML_TYPE_COUNT;
2319- 
2320- 
2321- #if  0 
2322-     GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
2323-     if (src0) {
2324-         GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
2325-                 ggml_is_contiguous(src0), src0->name);
2326-     }
2327-     if (src1) {
2328-         GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
2329-                 ggml_is_contiguous(src1), src1->name);
2330-     }
2331-     if (dst) {
2332-         GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
2333-                 dst->name);
2332+     if  (ctx_dev->debug_graph  > 0 ) {
2333+         GGML_LOG_INFO (" %s : op - %s , concurrent = %d \n "  , __func__, ggml_op_name (dst->op ), is_concurrent);
2334+         if  (src0) {
2335+             GGML_LOG_INFO (" %s : src0 - %4s  [%5lld , %5lld , %5lld , %5lld ] [%5lld , %5lld , %5lld , %5lld ], %d , %s \n "  , __func__, ggml_type_name (src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
2336+                     ggml_is_contiguous (src0), src0->name );
2337+         }
2338+         if  (src1) {
2339+             GGML_LOG_INFO (" %s : src1 - %4s  [%5lld , %5lld , %5lld , %5lld ] [%5lld , %5lld , %5lld , %5lld ], %d , %s \n "  , __func__, ggml_type_name (src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
2340+                     ggml_is_contiguous (src1), src1->name );
2341+         }
2342+         if  (dst) {
2343+             GGML_LOG_INFO (" %s : dst  - %4s  [%5lld , %5lld , %5lld , %5lld ] [%5lld , %5lld , %5lld , %5lld ], 1, %s \n "  , __func__, ggml_type_name (dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
2344+                     dst->name );
2345+         }
23342346    }
2335- #endif 
23362347
23372348    id <MTLDevice > device = ctx_dev->mtl_device ;
23382349
@@ -2676,7 +2687,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
26762687
26772688                    [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
26782689
2679-                     [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
2690+                     if  (ctx_dev->use_concurrency ) {
2691+                         [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
2692+                     }
26802693                }
26812694
26822695                const  id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ;
@@ -4201,7 +4214,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
42014214                        [encoder dispatchThreadgroups: MTLSizeMake (1 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (ne02, 1 , 1 )];
42024215                    }
42034216
4204-                     [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
4217+                     if  (ctx_dev->use_concurrency ) {
4218+                         [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
4219+                     }
42054220
42064221                    {
42074222                        id <MTLComputePipelineState > pipeline = nil ;
@@ -5592,7 +5607,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
55925607                        [encoder setThreadgroupMemoryLength: smem atIndex: 0 ];
55935608                        [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nqptg - 1 )/nqptg, ne02, ne03*nwg) threadsPerThreadgroup: MTLSizeMake (32 , nsg, 1 )];
55945609
5595-                         [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
5610+                         if  (ctx_dev->use_concurrency ) {
5611+                             [encoder memoryBarrierWithScope: MTLBarrierScopeBuffers ];
5612+                         }
55965613
55975614                        //  reduce the results from the workgroups
55985615                        {
@@ -5856,8 +5873,14 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
58565873            }
58575874    }
58585875
5876+     if  (ctx_dev->debug_graph  > 0 ) {
5877+         if  (n_fuse > 1 ) {
5878+             GGML_LOG_INFO (" %s : fuse: %d  ops\n "  , __func__, n_fuse);
5879+         }
5880+     }
5881+ 
58595882    //  after fusing, we have to add the new destination range to the encoding context
5860-     if  (n_fuse > 1 ) {
5883+     if  (ctx_dev-> use_concurrency  &&  n_fuse > 1 ) {
58615884        struct  ggml_tensor * dstf = nodes[n_fuse - 1 ];
58625885
58635886        size_t  offs_dstf = 0 ;
@@ -6743,7 +6766,15 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
67436766
67446767        ggml_metal_mem_pool_reset (mem_pool);
67456768
6746-         id <MTLComputeCommandEncoder > encoder = [cmd_buf computeCommandEncoderWithDispatchType:  MTLDispatchTypeConcurrent ];
6769+         id <MTLComputeCommandEncoder > encoder;
6770+ 
6771+         struct  ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
6772+ 
6773+         if  (ctx_dev->use_concurrency ) {
6774+             encoder = [cmd_buf computeCommandEncoderWithDispatchType:  MTLDispatchTypeConcurrent ];
6775+         } else  {
6776+             encoder = [cmd_buf computeCommandEncoder ];
6777+         }
67476778
67486779        int  node_start = 0 ;
67496780        int  node_end   = n_nodes_0;
0 commit comments