Commit 7dc9199

cont : refactor and handle fusing

ggml-ci

1 parent b6fb92f

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 182 additions & 105 deletions
@@ -2092,47 +2092,138 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         int pt; // type: 0 - src, 1 - dst
     } ranges[MEM_RANGE_MAX];
 
+    int debug;
 };
 
-static bool ggml_metal_encode_reset_mem_ranges(struct ggml_metal_encode_context * ctx_enc) {
-    ctx_enc->n_ranges = 0;
+static bool ggml_metal_encode_mem_ranges_reset(struct ggml_metal_encode_context * ctx) {
+    [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+
+    ctx->n_ranges = 0;
 
     return true;
 }
 
-static bool ggml_metal_encode_add_mem_range(struct ggml_metal_encode_context * ctx_enc, struct mem_range r) {
-    if (ctx_enc->n_ranges == MEM_RANGE_MAX) {
+static bool ggml_metal_encode_mem_ranges_add(struct ggml_metal_encode_context * ctx, struct mem_range r) {
+    if (ctx->n_ranges == MEM_RANGE_MAX) {
         return false;
     }
 
-    ctx_enc->ranges[ctx_enc->n_ranges] = r;
+    ctx->ranges[ctx->n_ranges] = r;
 
-    ctx_enc->n_ranges++;
+    ctx->n_ranges++;
 
     return true;
 }
 
-// check and return true:
-// - if new range overlaps with any existing range of a different type
-// - if we are close to running out of range cells
-static bool ggml_metal_encode_check_mem_range(struct ggml_metal_encode_context * ctx_enc, struct mem_range r) {
-    if (ctx_enc->n_ranges + 2*GGML_MAX_SRC >= MEM_RANGE_MAX) {
+static bool ggml_metal_encode_mem_ranges_add_src(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!node) {
         return true;
     }
 
-    for (int i = 0; i < ctx_enc->n_ranges; i++) {
-        if (ctx_enc->ranges[i].pt == r.pt) {
+    size_t offs = 0;
+    id<MTLBuffer> id_node = ggml_metal_get_buffer(node, &offs);
+    GGML_ASSERT(id_node != nil);
+
+    struct mem_range r = {
+        /*.p0 =*/ id_node.gpuAddress + offs,
+        /*.p1 =*/ id_node.gpuAddress + offs + ggml_nbytes(node),
+        /*.pt =*/ 0,
+    };
+
+    if (ctx->debug > 2) {
+        GGML_LOG_DEBUG("%s: add src range [%lld, %lld)\n", __func__, r.p0, r.p1);
+    }
+
+    return ggml_metal_encode_mem_ranges_add(ctx, r);
+}
+
+static bool ggml_metal_encode_mem_ranges_add_dst(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    GGML_ASSERT(node);
+
+    size_t offs = 0;
+    id<MTLBuffer> id_node = ggml_metal_get_buffer(node, &offs);
+    GGML_ASSERT(id_node != nil);
+
+    struct mem_range r = {
+        /*.p0 =*/ id_node.gpuAddress + offs,
+        /*.p1 =*/ id_node.gpuAddress + offs + ggml_nbytes(node),
+        /*.pt =*/ 1,
+    };
+
+    if (ctx->debug > 2) {
+        GGML_LOG_DEBUG("%s: add dst range [%lld, %lld)\n", __func__, r.p0, r.p1);
+    }
+
+    return ggml_metal_encode_mem_ranges_add(ctx, r);
+}
+
+// return true if:
+// - new src range overlaps with any existing dst range
+// - new dst range overlaps with any existing range (src or dst)
+static bool ggml_metal_encode_mem_ranges_check(const struct ggml_metal_encode_context * ctx, struct mem_range r) {
+    for (int i = 0; i < ctx->n_ranges; i++) {
+        if (r.pt == 0 && ctx->ranges[i].pt == 0) {
             continue;
         }
 
-        if (r.p0 < ctx_enc->ranges[i].p1 && r.p1 > ctx_enc->ranges[i].p0) {
+        if (r.p0 < ctx->ranges[i].p1 && r.p1 > ctx->ranges[i].p0) {
             return true;
         }
     }
 
     return false;
 }
 
+static bool ggml_metal_encode_mem_ranges_check_src(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!node) {
+        return false;
+    }
+
+    size_t offs = 0;
+    id<MTLBuffer> id_node = ggml_metal_get_buffer(node, &offs);
+    GGML_ASSERT(id_node != nil);
+
+    struct mem_range r = {
+        /*.p0 =*/ id_node.gpuAddress + offs,
+        /*.p1 =*/ id_node.gpuAddress + offs + ggml_nbytes(node),
+        /*.pt =*/ 0,
+    };
+
+    const bool res = ggml_metal_encode_mem_ranges_check(ctx, r);
+
+    if (res) {
+        if (ctx->debug > 2) {
+            GGML_LOG_DEBUG("%s: the src range [%lld, %lld) overlaps with a previous dst range\n", __func__, r.p0, r.p1);
+        }
+    }
+
+    return res;
+}
+
+static bool ggml_metal_encode_mem_ranges_check_dst(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    GGML_ASSERT(node);
+
+    size_t offs = 0;
+    id<MTLBuffer> id_node = ggml_metal_get_buffer(node, &offs);
+    GGML_ASSERT(id_node != nil);
+
+    struct mem_range r = {
+        /*.p0 =*/ id_node.gpuAddress + offs,
+        /*.p1 =*/ id_node.gpuAddress + offs + ggml_nbytes(node),
+        /*.pt =*/ 1,
+    };
+
+    const bool res = ggml_metal_encode_mem_ranges_check(ctx, r);
+
+    if (res) {
+        if (ctx->debug > 2) {
+            GGML_LOG_DEBUG("%s: the dst range [%lld, %lld) overlaps with a previous src range\n", __func__, r.p0, r.p1);
+        }
+    }
+
+    return res;
+}
+
 static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
     ggml_backend_t backend = ctx_enc->backend;
 
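The refactor above centralizes the concurrency bookkeeping into small helpers: ranges are half-open address intervals [p0, p1) tagged as src (read) or dst (write), and ggml_metal_encode_mem_ranges_check flags a hazard only when two intervals overlap and at least one side writes. Below is a minimal, standalone C sketch of that predicate; ranges_conflict is a hypothetical name for illustration, not part of the commit:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    struct mem_range {
        uint64_t p0; // start of the range (inclusive)
        uint64_t p1; // end of the range (exclusive)
        int      pt; // type: 0 - src (read), 1 - dst (write)
    };

    // mirrors the test in ggml_metal_encode_mem_ranges_check: a conflict
    // requires an interval overlap plus at least one writer
    static bool ranges_conflict(struct mem_range a, struct mem_range b) {
        if (a.pt == 0 && b.pt == 0) {
            return false; // read/read never conflicts
        }
        return a.p0 < b.p1 && b.p0 < a.p1; // half-open interval overlap
    }

    int main(void) {
        struct mem_range rd = { 0x1000, 0x2000, 0 }; // one kernel reads this range
        struct mem_range wr = { 0x1800, 0x2800, 1 }; // a later kernel writes here

        printf("conflict: %d\n", ranges_conflict(rd, wr)); // prints 1 -> barrier needed

        return 0;
    }

Note also that ggml_metal_encode_mem_ranges_reset now issues the [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers] itself, so every reset site below pairs the barrier with clearing the tracked ranges.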
@@ -2254,94 +2345,47 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     int n_fuse = 1;
 
+    if (ctx_dev->debug_graph > 0) {
+        GGML_LOG_DEBUG("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+        if (src0) {
+            GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                ggml_is_contiguous(src0), src0->name);
+        }
+        if (src1) {
+            GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                ggml_is_contiguous(src1), src1->name);
+        }
+        if (dst) {
+            GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst->name);
+        }
+    }
+
     // check if the current node can run concurrently with other nodes before it
     // the condition is that:
-    // - the current node cannot write to any previous src ranges
+    // - the current node cannot write to any previous src or dst ranges
     // - the current node cannot read from any previous dst ranges
     //
     // if the condition is not satisfied, we put a memory barrier and clear all ranges
-    // otherwise, we add the new ranges to the encoding context and add the node for concurrent execution
+    // otherwise, we add the new ranges to the encoding context and process the node concurrently
     //
-    bool is_concurrent = ctx_dev->use_concurrency;
+    if (ctx_dev->use_concurrency) {
+        bool is_concurrent = true;
 
-    if (is_concurrent) {
         // do not read from any previous dst ranges
         for (int i = 0; i < GGML_MAX_SRC; i++) {
-            if (id_src[i] == nil) {
-                continue;
-            }
-
-            struct mem_range r = {
-                /*.p0 =*/ id_src[i].gpuAddress + offs_src[i],
-                /*.p1 =*/ id_src[i].gpuAddress + offs_src[i] + ggml_nbytes(node->src[i]),
-                /*.pt =*/ 0,
-            };
-
-            if (ggml_metal_encode_check_mem_range(ctx_enc, r)) {
-                is_concurrent = false;
-
-                break;
-            }
+            is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_src(ctx_enc, node->src[i]);
         }
 
-        // do not write to any previous src ranges
-        if (is_concurrent) {
-            struct mem_range r = {
-                /*.p0 =*/ id_dst.gpuAddress + offs_dst,
-                /*.p1 =*/ id_dst.gpuAddress + offs_dst + ggml_nbytes(dst),
-                /*.pt =*/ 1,
-            };
-
-            if (ggml_metal_encode_check_mem_range(ctx_enc, r)) {
-                is_concurrent = false;
-            }
-        }
+        // do not write to any previous ranges
+        is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_dst(ctx_enc, dst);
 
         if (!is_concurrent) {
-            ggml_metal_encode_reset_mem_ranges(ctx_enc);
-
-            [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+            ggml_metal_encode_mem_ranges_reset(ctx_enc);
         }
 
-        // add new ranges
-        for (int i = 0; i < GGML_MAX_SRC; i++) {
-            if (id_src[i] == nil) {
-                continue;
-            }
-
-            struct mem_range r = {
-                /*.p0 =*/ id_src[i].gpuAddress + offs_src[i],
-                /*.p1 =*/ id_src[i].gpuAddress + offs_src[i] + ggml_nbytes(node->src[i]),
-                /*.pt =*/ 0,
-            };
-
-            ggml_metal_encode_add_mem_range(ctx_enc, r);
-        }
-
-        {
-            struct mem_range r = {
-                /*.p0 =*/ id_dst.gpuAddress + offs_dst,
-                /*.p1 =*/ id_dst.gpuAddress + offs_dst + ggml_nbytes(dst),
-                /*.pt =*/ 1,
-            };
-
-            ggml_metal_encode_add_mem_range(ctx_enc, r);
-        }
-    }
-
-    if (ctx_dev->debug_graph > 0) {
-        GGML_LOG_INFO("%s: op - %s, concurrent = %d\n", __func__, ggml_op_name(dst->op), is_concurrent);
-        if (src0) {
-            GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
-                ggml_is_contiguous(src0), src0->name);
-        }
-        if (src1) {
-            GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
-                ggml_is_contiguous(src1), src1->name);
-        }
-        if (dst) {
-            GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
-                dst->name);
+        if (ctx_dev->debug_graph > 0) {
+            GGML_LOG_DEBUG("%s: concurrent = %d\n", __func__, is_concurrent);
         }
     }
 
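Per node, the flow is: check the node's sources against previously recorded dst ranges (read-after-write), check its destination against all recorded ranges (write-after-read and write-after-write), and on any hazard emit a barrier and clear the window; otherwise the node joins the current concurrent window. A self-contained C sketch of this window logic under the same overlap rule — the encode/check names and the address values are illustrative, not the commit's API:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    #define MEM_RANGE_MAX 128

    struct mem_range { uint64_t p0, p1; int pt; }; // pt: 0 - src, 1 - dst

    static struct mem_range ranges[MEM_RANGE_MAX];
    static int n_ranges = 0;

    // hazard if r overlaps a recorded range and at least one side writes
    static bool check(struct mem_range r) {
        for (int i = 0; i < n_ranges; i++) {
            if (r.pt == 0 && ranges[i].pt == 0) {
                continue; // read/read is always safe
            }
            if (r.p0 < ranges[i].p1 && r.p1 > ranges[i].p0) {
                return true;
            }
        }
        return false;
    }

    // one node with a single src range and a single dst range
    static void encode(uint64_t s0, uint64_t s1, uint64_t d0, uint64_t d1) {
        struct mem_range src = { s0, s1, 0 };
        struct mem_range dst = { d0, d1, 1 };

        if (check(src) || check(dst)) {
            printf("barrier\n"); // stands in for ggml_metal_encode_mem_ranges_reset()
            n_ranges = 0;
        }

        ranges[n_ranges++] = src;
        ranges[n_ranges++] = dst;
    }

    int main(void) {
        encode(0x0000, 0x1000, 0x1000, 0x2000); // node 1
        encode(0x3000, 0x4000, 0x4000, 0x5000); // node 2: disjoint -> stays concurrent
        encode(0x1000, 0x2000, 0x5000, 0x6000); // node 3: reads node 1's dst -> "barrier"
        return 0;
    }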
@@ -2544,6 +2588,22 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
             }
 
+            if (ctx_dev->use_concurrency && n_fuse > 1) {
+                bool is_concurrent = true;
+
+                // make sure that none of the fused nodes reads from a previous dst range
+                for (int i = 1; i < n_fuse; ++i) {
+                    is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_src(ctx_enc, nodes[i]->src[1]);
+                }
+
+                // do not write to any previous range
+                is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_dst(ctx_enc, nodes[n_fuse - 1]);
+
+                if (!is_concurrent) {
+                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
+                }
+            }
+
             [encoder setComputePipelineState:pipeline];
             [encoder setBytes:&args length:sizeof(args) atIndex:0];
             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
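When n_fuse > 1, the fused kernel gains inputs and a destination beyond what the pre-fusion check at the top of the function saw, which is why the extra ranges are re-checked here. The sketch below illustrates the dataflow for a fused chain of two binary ops; the use of src[1] matches the diff, while the claim that src[0] of each fused tail node is the previous node's output is an assumption based on how ggml expresses such chains:

    // hypothetical fused chain of two adds:
    //
    //     t = add(a, b)      // nodes[0]: src[0] = a, src[1] = b
    //     d = add(t, c)      // nodes[1]: src[0] = t, src[1] = c
    //
    // the fused kernel reads a, b, c and writes d; the intermediate t never
    // has to be consistent in memory, so only the tail sources
    // nodes[i]->src[1] (here: c) and the final destination nodes[n_fuse - 1]
    // (here: d) need to be checked against the current concurrency window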
@@ -2688,7 +2748,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 
                 if (ctx_dev->use_concurrency) {
-                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
                 }
             }
 
@@ -4215,7 +4275,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
             }
 
             if (ctx_dev->use_concurrency) {
-                [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                ggml_metal_encode_mem_ranges_reset(ctx_enc);
             }
 
             {
@@ -4688,6 +4748,22 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
             }
 
+            if (ctx_dev->use_concurrency) {
+                bool is_concurrent = true;
+
+                // make sure that none of the fused nodes reads from a previous dst range
+                for (int i = 1; i < n_fuse; ++i) {
+                    is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_src(ctx_enc, nodes[i]->src[1]);
+                }
+
+                // do not write to any previous range
+                is_concurrent = is_concurrent && !ggml_metal_encode_mem_ranges_check_dst(ctx_enc, nodes[n_fuse - 1]);
+
+                if (!is_concurrent) {
+                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
+                }
+            }
+
             id<MTLComputePipelineState> pipeline;
 
             switch (n_fuse) {
@@ -5608,7 +5684,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 
                 if (ctx_dev->use_concurrency) {
-                    [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+                    ggml_metal_encode_mem_ranges_reset(ctx_enc);
                 }
 
                 // reduce the results from the workgroups
@@ -5875,28 +5951,28 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     if (ctx_dev->debug_graph > 0) {
         if (n_fuse > 1) {
-            GGML_LOG_INFO("%s: fuse: %d ops\n", __func__, n_fuse);
+            GGML_LOG_DEBUG("%s: fuse: %d ops\n", __func__, n_fuse);
         }
     }
 
-    // after fusing, we have to add the new destination range to the encoding context
-    if (ctx_dev->use_concurrency && n_fuse > 1) {
-        struct ggml_tensor * dstf = nodes[n_fuse - 1];
-
-        size_t offs_dstf = 0;
+    // update the mem ranges in the encoding context
+    if (ctx_dev->use_concurrency) {
+        bool ok = true;
 
-        id<MTLBuffer> id_dstf = dstf ? ggml_metal_get_buffer(dstf, &offs_dstf) : nil;
+        // add new src ranges
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            ok = ok && ggml_metal_encode_mem_ranges_add_src(ctx_enc, node->src[i]);
+        }
 
-        struct mem_range r = {
-            /*.p0 =*/ id_dstf.gpuAddress + offs_dstf,
-            /*.p1 =*/ id_dstf.gpuAddress + offs_dstf + ggml_nbytes(dstf),
-            /*.pt =*/ 1,
-        };
+        // add the destination range
+        ok = ok && ggml_metal_encode_mem_ranges_add_dst(ctx_enc, nodes[n_fuse - 1]);
 
-        if (!ggml_metal_encode_add_mem_range(ctx_enc, r)) {
-            ggml_metal_encode_reset_mem_ranges(ctx_enc);
+        if (!ok) {
+            if (ctx_dev->debug_graph > 2) {
+                GGML_LOG_DEBUG("%s: the range cache is full -> reset and put a barrier\n", __func__);
+            }
 
-            [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+            ggml_metal_encode_mem_ranges_reset(ctx_enc);
         }
     }
 
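The node's ranges are recorded only after its kernel is encoded, so they constrain the nodes that follow; the `ok = ok && ...` chain short-circuits once the fixed-size range cache fills, and the fallback is the usual barrier-plus-reset. A tiny standalone C sketch of that overflow path, with MEM_RANGE_MAX shrunk on purpose and illustrative names:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    #define MEM_RANGE_MAX 4 // deliberately tiny to force the overflow path

    struct mem_range { uint64_t p0, p1; int pt; };

    static struct mem_range ranges[MEM_RANGE_MAX];
    static int n_ranges = 0;

    // like ggml_metal_encode_mem_ranges_add: refuses once the cache is full
    static bool add(struct mem_range r) {
        if (n_ranges == MEM_RANGE_MAX) {
            return false;
        }
        ranges[n_ranges++] = r;
        return true;
    }

    int main(void) {
        bool ok = true;

        // record six ranges; the 5th add fails and && short-circuits the rest
        for (uint64_t i = 0; i < 6; i++) {
            ok = ok && add((struct mem_range) { 0x1000*i, 0x1000*(i + 1), (int)(i & 1) });
        }

        if (!ok) {
            printf("range cache full -> reset + barrier\n");
            n_ranges = 0; // the real reset also issues the Metal memory barrier
        }

        return 0;
    }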
@@ -6792,6 +6868,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         /*.mem_pool =*/ mem_pool,
         /*.n_ranges =*/ 0,
         /*.ranges   =*/ { 0 },
+        /*.debug    =*/ ctx_dev->debug_graph,
     };
 
     for (int idx = node_start; idx < node_end;) {