Skip to content

Commit b6fb92f

Browse files
committed
cont: add flags for debugging and disabling concurrency
ggml-ci
1 parent ea7ce2d commit b6fb92f

File tree

1 file changed

+68
-37
lines changed

1 file changed

+68
-37
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 68 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@
6161
bool has_bfloat;
6262
bool use_bfloat;
6363
bool use_fusion;
64+
bool use_concurrency;
6465
bool use_shared_buffers;
6566

67+
int debug_graph;
6668
int debug_fusion;
6769

6870
// how many times a given op was fused
@@ -83,7 +85,9 @@
8385
/*.has_bfloat =*/ false,
8486
/*.use_bfloat =*/ false,
8587
/*.use_fusion =*/ true,
88+
/*.use_concurrency =*/ true,
8689
/*.use_shared_buffers =*/ true,
90+
/*.debug_graph =*/ 0,
8791
/*.debug_fusion =*/ 0,
8892
/*.fuse_cnt =*/ { 0 },
8993
/*.max_size =*/ 0,
@@ -124,7 +128,14 @@
124128
#else
125129
ctx->use_bfloat = false;
126130
#endif
127-
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
131+
132+
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
133+
ctx->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
134+
135+
{
136+
const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
137+
ctx->debug_graph = val ? atoi(val) : 0;
138+
}
128139

129140
{
130141
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
@@ -1091,6 +1102,7 @@ @implementation GGMLMetalClass
10911102
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
10921103
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
10931104
GGML_LOG_INFO("%s: use fusion = %s\n", __func__, ctx_dev->use_fusion ? "true" : "false");
1105+
GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, ctx_dev->use_concurrency ? "true" : "false");
10941106
GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, ctx_dev->use_shared_buffers ? "true" : "false");
10951107
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
10961108

@@ -2222,6 +2234,20 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
22222234
srct[i] = node->src[i] ? node->src[i]->type : GGML_TYPE_COUNT;
22232235
}
22242236

2237+
// TODO: tmp shorthands - remove
2238+
size_t offs_src0 = offs_src[0];
2239+
size_t offs_src1 = offs_src[1];
2240+
size_t offs_src2 = offs_src[2];
2241+
2242+
id<MTLBuffer> id_src0 = id_src[0];
2243+
id<MTLBuffer> id_src1 = id_src[1];
2244+
id<MTLBuffer> id_src2 = id_src[2];
2245+
2246+
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
2247+
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
2248+
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
2249+
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
2250+
22252251
size_t offs_dst = 0;
22262252

22272253
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
@@ -2236,9 +2262,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
22362262
// if the condition is not satisfied, we put a memory barrier and clear all ranges
22372263
// otherwise, we add the new ranges to the encoding context and add the node for concurrent execution
22382264
//
2239-
{
2240-
bool is_concurrent = true;
2265+
bool is_concurrent = ctx_dev->use_concurrency;
22412266

2267+
if (is_concurrent) {
22422268
// do not read from any previous dst ranges
22432269
for (int i = 0; i < GGML_MAX_SRC; i++) {
22442270
if (id_src[i] == nil) {
@@ -2303,36 +2329,21 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
23032329
}
23042330
}
23052331

2306-
// TODO: tmp shorthands - remove
2307-
size_t offs_src0 = offs_src[0];
2308-
size_t offs_src1 = offs_src[1];
2309-
size_t offs_src2 = offs_src[2];
2310-
2311-
id<MTLBuffer> id_src0 = id_src[0];
2312-
id<MTLBuffer> id_src1 = id_src[1];
2313-
id<MTLBuffer> id_src2 = id_src[2];
2314-
2315-
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
2316-
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
2317-
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
2318-
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
2319-
2320-
2321-
#if 0
2322-
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
2323-
if (src0) {
2324-
GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
2325-
ggml_is_contiguous(src0), src0->name);
2326-
}
2327-
if (src1) {
2328-
GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
2329-
ggml_is_contiguous(src1), src1->name);
2330-
}
2331-
if (dst) {
2332-
GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
2333-
dst->name);
2332+
if (ctx_dev->debug_graph > 0) {
2333+
GGML_LOG_INFO("%s: op - %s, concurrent = %d\n", __func__, ggml_op_name(dst->op), is_concurrent);
2334+
if (src0) {
2335+
GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
2336+
ggml_is_contiguous(src0), src0->name);
2337+
}
2338+
if (src1) {
2339+
GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
2340+
ggml_is_contiguous(src1), src1->name);
2341+
}
2342+
if (dst) {
2343+
GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
2344+
dst->name);
2345+
}
23342346
}
2335-
#endif
23362347

23372348
id<MTLDevice> device = ctx_dev->mtl_device;
23382349

@@ -2676,7 +2687,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
26762687

26772688
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
26782689

2679-
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
2690+
if (ctx_dev->use_concurrency) {
2691+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
2692+
}
26802693
}
26812694

26822695
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
@@ -4201,7 +4214,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
42014214
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
42024215
}
42034216

4204-
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
4217+
if (ctx_dev->use_concurrency) {
4218+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
4219+
}
42054220

42064221
{
42074222
id<MTLComputePipelineState> pipeline = nil;
@@ -5592,7 +5607,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
55925607
[encoder setThreadgroupMemoryLength:smem atIndex:0];
55935608
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
55945609

5595-
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
5610+
if (ctx_dev->use_concurrency) {
5611+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
5612+
}
55965613

55975614
// reduce the results from the workgroups
55985615
{
@@ -5856,8 +5873,14 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
58565873
}
58575874
}
58585875

5876+
if (ctx_dev->debug_graph > 0) {
5877+
if (n_fuse > 1) {
5878+
GGML_LOG_INFO("%s: fuse: %d ops\n", __func__, n_fuse);
5879+
}
5880+
}
5881+
58595882
// after fusing, we have to add the new destination range to the encoding context
5860-
if (n_fuse > 1) {
5883+
if (ctx_dev->use_concurrency && n_fuse > 1) {
58615884
struct ggml_tensor * dstf = nodes[n_fuse - 1];
58625885

58635886
size_t offs_dstf = 0;
@@ -6743,7 +6766,15 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
67436766

67446767
ggml_metal_mem_pool_reset(mem_pool);
67456768

6746-
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
6769+
id<MTLComputeCommandEncoder> encoder;
6770+
6771+
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
6772+
6773+
if (ctx_dev->use_concurrency) {
6774+
encoder = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
6775+
} else {
6776+
encoder = [cmd_buf computeCommandEncoder];
6777+
}
67476778

67486779
int node_start = 0;
67496780
int node_end = n_nodes_0;

0 commit comments

Comments (0)