Skip to content

Commit ea7ce2d

Browse files
committed
metal : run graphs ops concurrently
ggml-ci
1 parent 9de447d commit ea7ce2d

File tree

1 file changed

+205
-21
lines changed

1 file changed

+205
-21
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 205 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,12 +2063,71 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
20632063
}
20642064
}
20652065

2066-
static int ggml_metal_encode_node(
2067-
ggml_backend_t backend,
2068-
int idx,
2069-
int idx_end,
2070-
id<MTLComputeCommandEncoder> encoder,
2071-
struct ggml_metal_mem_pool * mem_pool) {
2066+
// maximum number of memory ranges tracked per command encoder; when the list is
// close to full, the encoder conservatively falls back to a memory barrier
// (see ggml_metal_encode_check_mem_range)
#define MEM_RANGE_MAX 128
2067+
2068+
// per-encoder state used to decide whether graph ops can be dispatched
// concurrently without hazards (encoder is created with MTLDispatchTypeConcurrent)
struct ggml_metal_encode_context {
2069+
ggml_backend_t backend;
2070+
2071+
id<MTLComputeCommandEncoder> encoder;
2072+
2073+
struct ggml_metal_mem_pool * mem_pool;
2074+
2075+
// number of valid entries in ranges[]
int n_ranges;
2076+
2077+
// a [p0, p1) GPU address range that has been read (src) or written (dst)
// since the last memory barrier
struct mem_range {
2078+
uint64_t p0; // begin
2079+
uint64_t p1; // end
2080+
int pt; // type: 0 - src, 1 - dst
2081+
} ranges[MEM_RANGE_MAX];
2082+
2083+
};
2084+
2085+
// clear all tracked memory ranges - called right before/after a memory barrier
// is encoded, since the barrier resolves all outstanding hazards
static bool ggml_metal_encode_reset_mem_ranges(struct ggml_metal_encode_context * ctx_enc) {
2086+
ctx_enc->n_ranges = 0;
2087+
2088+
return true;
2089+
}
2090+
2091+
// append a memory range to the encoder's tracking list
// returns false when the list is full - the caller is expected to emit a
// memory barrier and reset the ranges in that case
static bool ggml_metal_encode_add_mem_range(struct ggml_metal_encode_context * ctx_enc, struct mem_range r) {
2092+
if (ctx_enc->n_ranges == MEM_RANGE_MAX) {
2093+
return false;
2094+
}
2095+
2096+
ctx_enc->ranges[ctx_enc->n_ranges] = r;
2097+
2098+
ctx_enc->n_ranges++;
2099+
2100+
return true;
2101+
}
2102+
2103+
// check and return true:
2104+
// - if new range overlaps with any existing range of a different type
2105+
// - if we are close to running out of range cells
2106+
// a true result means the current op cannot run concurrently with the ops
// already encoded - the caller must insert a memory barrier
static bool ggml_metal_encode_check_mem_range(struct ggml_metal_encode_context * ctx_enc, struct mem_range r) {
2107+
// conservative: a single node can add up to 2*GGML_MAX_SRC new ranges, so
// bail out early instead of risking a failed add later
if (ctx_enc->n_ranges + 2*GGML_MAX_SRC >= MEM_RANGE_MAX) {
2108+
return true;
2109+
}
2110+
2111+
for (int i = 0; i < ctx_enc->n_ranges; i++) {
2112+
// same-type overlaps are safe: src/src is read-after-read, and dst/dst
// ordering is handled by the caller's hazard logic
if (ctx_enc->ranges[i].pt == r.pt) {
2113+
continue;
2114+
}
2115+
2116+
// half-open interval intersection test: [r.p0, r.p1) vs [p0, p1)
if (r.p0 < ctx_enc->ranges[i].p1 && r.p1 > ctx_enc->ranges[i].p0) {
2117+
return true;
2118+
}
2119+
}
2120+
2121+
return false;
2122+
}
2123+
2124+
static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
2125+
ggml_backend_t backend = ctx_enc->backend;
2126+
2127+
id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
2128+
2129+
struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool;
2130+
20722131
struct ggml_backend_metal_context * ctx = backend->context;
20732132
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
20742133

@@ -2151,22 +2210,113 @@ static int ggml_metal_encode_node(
21512210
const uint64_t nb2 = dst ? dst->nb[2] : 0;
21522211
const uint64_t nb3 = dst ? dst->nb[3] : 0;
21532212

2213+
size_t offs_src[GGML_MAX_SRC];
2214+
2215+
id<MTLBuffer> id_src[GGML_MAX_SRC];
2216+
2217+
enum ggml_type srct[GGML_MAX_SRC];
2218+
2219+
for (int i = 0; i < GGML_MAX_SRC; i++) {
2220+
offs_src[i] = 0;
2221+
id_src[i] = node->src[i] ? ggml_metal_get_buffer(node->src[i], &offs_src[i]) : nil;
2222+
srct[i] = node->src[i] ? node->src[i]->type : GGML_TYPE_COUNT;
2223+
}
2224+
2225+
size_t offs_dst = 0;
2226+
2227+
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
2228+
2229+
int n_fuse = 1;
2230+
2231+
// check if the current node can run concurrently with other nodes before it
2232+
// the condition is that:
2233+
// - the current node cannot write to any previous src ranges
2234+
// - the current node cannot read from any previous dst ranges
2235+
//
2236+
// if the condition is not satisfied, we put a memory barrier and clear all ranges
2237+
// otherwise, we add the new ranges to the encoding context and add the node for concurrent execution
2238+
//
2239+
{
2240+
bool is_concurrent = true;
2241+
2242+
// do not read from any previous dst ranges
2243+
for (int i = 0; i < GGML_MAX_SRC; i++) {
2244+
if (id_src[i] == nil) {
2245+
continue;
2246+
}
2247+
2248+
struct mem_range r = {
2249+
/*.p0 =*/ id_src[i].gpuAddress + offs_src[i],
2250+
/*.p1 =*/ id_src[i].gpuAddress + offs_src[i] + ggml_nbytes(node->src[i]),
2251+
/*.pt =*/ 0,
2252+
};
2253+
2254+
if (ggml_metal_encode_check_mem_range(ctx_enc, r)) {
2255+
is_concurrent = false;
2256+
2257+
break;
2258+
}
2259+
}
2260+
2261+
// do not write to any previous src ranges
2262+
if (is_concurrent) {
2263+
struct mem_range r = {
2264+
/*.p0 =*/ id_dst.gpuAddress + offs_dst,
2265+
/*.p1 =*/ id_dst.gpuAddress + offs_dst + ggml_nbytes(dst),
2266+
/*.pt =*/ 1,
2267+
};
2268+
2269+
if (ggml_metal_encode_check_mem_range(ctx_enc, r)) {
2270+
is_concurrent = false;
2271+
}
2272+
}
2273+
2274+
if (!is_concurrent) {
2275+
ggml_metal_encode_reset_mem_ranges(ctx_enc);
2276+
2277+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
2278+
}
2279+
2280+
// add new ranges
2281+
for (int i = 0; i < GGML_MAX_SRC; i++) {
2282+
if (id_src[i] == nil) {
2283+
continue;
2284+
}
2285+
2286+
struct mem_range r = {
2287+
/*.p0 =*/ id_src[i].gpuAddress + offs_src[i],
2288+
/*.p1 =*/ id_src[i].gpuAddress + offs_src[i] + ggml_nbytes(node->src[i]),
2289+
/*.pt =*/ 0,
2290+
};
2291+
2292+
ggml_metal_encode_add_mem_range(ctx_enc, r);
2293+
}
2294+
2295+
{
2296+
struct mem_range r = {
2297+
/*.p0 =*/ id_dst.gpuAddress + offs_dst,
2298+
/*.p1 =*/ id_dst.gpuAddress + offs_dst + ggml_nbytes(dst),
2299+
/*.pt =*/ 1,
2300+
};
2301+
2302+
ggml_metal_encode_add_mem_range(ctx_enc, r);
2303+
}
2304+
}
2305+
2306+
// TODO: tmp shorthands - remove
2307+
size_t offs_src0 = offs_src[0];
2308+
size_t offs_src1 = offs_src[1];
2309+
size_t offs_src2 = offs_src[2];
2310+
2311+
id<MTLBuffer> id_src0 = id_src[0];
2312+
id<MTLBuffer> id_src1 = id_src[1];
2313+
id<MTLBuffer> id_src2 = id_src[2];
2314+
21542315
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
21552316
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
21562317
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
21572318
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
21582319

2159-
size_t offs_src0 = 0;
2160-
size_t offs_src1 = 0;
2161-
size_t offs_src2 = 0;
2162-
size_t offs_dst = 0;
2163-
2164-
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
2165-
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
2166-
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
2167-
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
2168-
2169-
int n_fuse = 1;
21702320

21712321
#if 0
21722322
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@@ -2525,6 +2675,8 @@ static int ggml_metal_encode_node(
25252675
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
25262676

25272677
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2678+
2679+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
25282680
}
25292681

25302682
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
@@ -4049,6 +4201,8 @@ static int ggml_metal_encode_node(
40494201
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
40504202
}
40514203

4204+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
4205+
40524206
{
40534207
id<MTLComputePipelineState> pipeline = nil;
40544208

@@ -4660,7 +4814,6 @@ static int ggml_metal_encode_node(
46604814
} break;
46614815
case GGML_OP_ROPE:
46624816
{
4663-
46644817
// make sure we have one or more position id(ne10) per token(ne02)
46654818
GGML_ASSERT(ne10 % ne02 == 0);
46664819
GGML_ASSERT(ne10 >= ne02);
@@ -5439,6 +5592,8 @@ static int ggml_metal_encode_node(
54395592
[encoder setThreadgroupMemoryLength:smem atIndex:0];
54405593
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
54415594

5595+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
5596+
54425597
// reduce the results from the workgroups
54435598
{
54445599
ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
@@ -5669,7 +5824,7 @@ static int ggml_metal_encode_node(
56695824

56705825
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
56715826
} break;
5672-
case GGML_OP_ARGMAX:
5827+
case GGML_OP_ARGMAX:
56735828
{
56745829
GGML_ASSERT(src0->type == GGML_TYPE_F32);
56755830
GGML_ASSERT(ggml_is_contiguous_1(src0));
@@ -5701,6 +5856,27 @@ static int ggml_metal_encode_node(
57015856
}
57025857
}
57035858

5859+
// after fusing, we have to add the new destination range to the encoding context
5860+
if (n_fuse > 1) {
5861+
struct ggml_tensor * dstf = nodes[n_fuse - 1];
5862+
5863+
size_t offs_dstf = 0;
5864+
5865+
id<MTLBuffer> id_dstf = dstf ? ggml_metal_get_buffer(dstf, &offs_dstf) : nil;
5866+
5867+
struct mem_range r = {
5868+
/*.p0 =*/ id_dstf.gpuAddress + offs_dstf,
5869+
/*.p1 =*/ id_dstf.gpuAddress + offs_dstf + ggml_nbytes(dstf),
5870+
/*.pt =*/ 1,
5871+
};
5872+
5873+
if (!ggml_metal_encode_add_mem_range(ctx_enc, r)) {
5874+
ggml_metal_encode_reset_mem_ranges(ctx_enc);
5875+
5876+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
5877+
}
5878+
}
5879+
57045880
return n_fuse;
57055881
}
57065882

@@ -6567,7 +6743,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
65676743

65686744
ggml_metal_mem_pool_reset(mem_pool);
65696745

6570-
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
6746+
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
65716747

65726748
int node_start = 0;
65736749
int node_end = n_nodes_0;
@@ -6579,12 +6755,20 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
65796755

65806756
const bool should_capture = ctx->capture_next_compute;
65816757

6758+
struct ggml_metal_encode_context ctx_enc = {
6759+
/*.backend =*/ backend,
6760+
/*.encoder =*/ encoder,
6761+
/*.mem_pool =*/ mem_pool,
6762+
/*.n_ranges =*/ 0,
6763+
/*.ranges =*/ { 0 },
6764+
};
6765+
65826766
for (int idx = node_start; idx < node_end;) {
65836767
if (should_capture) {
65846768
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
65856769
}
65866770

6587-
const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
6771+
const int res = ggml_metal_encode_node(&ctx_enc, idx, node_end);
65886772
if (idx + res > node_end) {
65896773
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
65906774
"https://github.com/ggml-org/llama.cpp/pull/14849");

0 commit comments

Comments
 (0)