Skip to content

Commit bed2175

Browse files
committed
cont : prepare mem ranges for reuse + add ggml-metal-common.cpp
ggml-ci
1 parent 589ab9c commit bed2175

File tree

4 files changed

+195
-99
lines changed

4 files changed

+195
-99
lines changed

ggml/src/ggml-metal/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ message(STATUS "Metal framework found")
66

77
ggml_add_backend_library(ggml-metal
88
ggml-metal.m
9+
ggml-metal-common.cpp
910
)
1011

1112
target_link_libraries(ggml-metal PRIVATE
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#include "ggml-metal-common.h"
2+
3+
#include "ggml-impl.h"
4+
5+
#include <vector>
6+
7+
// keep this separate from the public ggml_mem_range_params
8+
struct ggml_mem_range {
9+
uint64_t p0; // being
10+
uint64_t p1; // end
11+
12+
enum ggml_mem_range_type pt;
13+
};
14+
15+
struct ggml_mem_ranges {
16+
std::vector<struct ggml_mem_range> ranges;
17+
18+
int debug = 0;
19+
};
20+
21+
struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
22+
auto * res = new struct ggml_mem_ranges;
23+
24+
res->debug = debug;
25+
26+
return res;
27+
}
28+
29+
void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs) {
30+
delete mrs;
31+
}
32+
33+
void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs) {
34+
mrs->ranges.clear();
35+
}
36+
37+
bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, struct ggml_mem_range_params mrp) {
38+
mrs->ranges.push_back({
39+
/*.p0 =*/ mrp.p0,
40+
/*.p1 =*/ mrp.p1,
41+
/*.pt =*/ mrp.pt,
42+
});
43+
44+
return true;
45+
}
46+
47+
bool ggml_mem_ranges_add_src(struct ggml_mem_ranges * mrs, const struct ggml_tensor * node) {
48+
GGML_ASSERT(node);
49+
50+
struct ggml_mem_range_params mrp = {
51+
/*.p0 =*/ (uint64_t) node->data,
52+
/*.p1 =*/ (uint64_t) node->data + ggml_nbytes(node),
53+
/*.pt =*/ MEM_RANGE_TYPE_SRC,
54+
};
55+
56+
if (mrs->debug > 2) {
57+
GGML_LOG_DEBUG("%s: add src range [%lld, %lld)\n", __func__, mrp.p0, mrp.p1);
58+
}
59+
60+
return ggml_mem_ranges_add(mrs, mrp);
61+
}
62+
63+
bool ggml_mem_ranges_add_dst(struct ggml_mem_ranges * mrs, const struct ggml_tensor * node) {
64+
GGML_ASSERT(node);
65+
66+
struct ggml_mem_range_params mrp = {
67+
/*.p0 =*/ (uint64_t) node->data,
68+
/*.p1 =*/ (uint64_t) node->data + ggml_nbytes(node),
69+
/*.pt =*/ MEM_RANGE_TYPE_DST,
70+
};
71+
72+
if (mrs->debug > 2) {
73+
GGML_LOG_DEBUG("%s: add dst range [%lld, %lld)\n", __func__, mrp.p0, mrp.p1);
74+
}
75+
76+
return ggml_mem_ranges_add(mrs, mrp);
77+
}
78+
79+
bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, struct ggml_mem_range_params mrp) {
80+
for (size_t i = 0; i < mrs->ranges.size(); i++) {
81+
if (mrp.pt == MEM_RANGE_TYPE_SRC && mrs->ranges[i].pt == MEM_RANGE_TYPE_SRC) {
82+
continue;
83+
}
84+
85+
if (mrp.p0 < mrs->ranges[i].p1 && mrp.p1 > mrs->ranges[i].p0) {
86+
return true;
87+
}
88+
}
89+
90+
return false;
91+
}
92+
93+
bool ggml_mem_ranges_check_src(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * node) {
94+
GGML_ASSERT(node);
95+
96+
struct ggml_mem_range_params mrp = {
97+
/*.p0 =*/ (uint64_t) node->data,
98+
/*.p1 =*/ (uint64_t) node->data + ggml_nbytes(node),
99+
/*.pt =*/ MEM_RANGE_TYPE_SRC,
100+
};
101+
102+
const bool res = ggml_mem_ranges_check(mrs, mrp);
103+
104+
if (res) {
105+
if (mrs->debug > 2) {
106+
GGML_LOG_DEBUG("%s: the src range [%lld, %lld) overlaps with a previous dst range\n", __func__, mrp.p0, mrp.p1);
107+
}
108+
}
109+
110+
return res;
111+
}
112+
113+
bool ggml_mem_ranges_check_dst(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * node) {
114+
GGML_ASSERT(node);
115+
116+
struct ggml_mem_range_params mrp = {
117+
/*.p0 =*/ (uint64_t) node->data,
118+
/*.p1 =*/ (uint64_t) node->data + ggml_nbytes(node),
119+
/*.pt =*/ MEM_RANGE_TYPE_DST,
120+
};
121+
122+
const bool res = ggml_mem_ranges_check(mrs, mrp);
123+
124+
if (res) {
125+
if (mrs->debug > 2) {
126+
GGML_LOG_DEBUG("%s: the dst range [%lld, %lld) overlaps with a previous src range\n", __func__, mrp.p0, mrp.p1);
127+
}
128+
}
129+
130+
return res;
131+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#pragma once
2+
3+
#include <stdint.h>
4+
#include <stdbool.h>
5+
6+
#ifdef __cplusplus
7+
extern "C" {
8+
#endif
9+
10+
struct ggml_tensor;
11+
12+
enum ggml_mem_range_type {
13+
MEM_RANGE_TYPE_SRC = 0,
14+
MEM_RANGE_TYPE_DST = 1,
15+
};
16+
17+
struct ggml_mem_range_params {
18+
uint64_t p0; // being
19+
uint64_t p1; // end
20+
21+
enum ggml_mem_range_type pt;
22+
};
23+
24+
struct ggml_mem_ranges;
25+
26+
struct ggml_mem_ranges * ggml_mem_ranges_init(int debug);
27+
void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs);
28+
29+
void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs);
30+
31+
bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, struct ggml_mem_range_params mrp);
32+
33+
bool ggml_mem_ranges_add_src(struct ggml_mem_ranges * mrs, const struct ggml_tensor * node);
34+
bool ggml_mem_ranges_add_dst(struct ggml_mem_ranges * mrs, const struct ggml_tensor * node);
35+
36+
// return true if:
37+
// - new src range overlaps with any existing dst range
38+
// - new dst range overlaps with any existing range (src or dst)
39+
bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, struct ggml_mem_range_params mrp);
40+
41+
bool ggml_mem_ranges_check_src(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * node);
42+
bool ggml_mem_ranges_check_dst(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * node);
43+
44+
#ifdef __cplusplus
45+
}
46+
#endif

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 17 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#import "ggml-impl.h"
44
#import "ggml-backend-impl.h"
55
#import "ggml-metal-impl.h"
6+
#import "ggml-metal-common.h"
67

78
#import <Foundation/Foundation.h>
89

@@ -2075,42 +2076,20 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
20752076
}
20762077
}
20772078

2078-
#define MEM_RANGE_MAX 128
2079-
20802079
struct ggml_metal_encode_context {
20812080
ggml_backend_t backend;
20822081

20832082
id<MTLComputeCommandEncoder> encoder;
20842083

20852084
struct ggml_metal_mem_pool * mem_pool;
20862085

2087-
int n_ranges;
2088-
2089-
struct mem_range {
2090-
uint64_t p0; // being
2091-
uint64_t p1; // end
2092-
int pt; // type: 0 - src, 1 - dst
2093-
} ranges[MEM_RANGE_MAX];
2094-
2095-
int debug;
2086+
struct ggml_mem_ranges * mem_ranges;
20962087
};
20972088

20982089
static bool ggml_metal_encode_mem_ranges_reset(struct ggml_metal_encode_context * ctx) {
20992090
[ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
21002091

2101-
ctx->n_ranges = 0;
2102-
2103-
return true;
2104-
}
2105-
2106-
static bool ggml_metal_encode_mem_ranges_add(struct ggml_metal_encode_context * ctx, struct mem_range r) {
2107-
if (ctx->n_ranges == MEM_RANGE_MAX) {
2108-
return false;
2109-
}
2110-
2111-
ctx->ranges[ctx->n_ranges] = r;
2112-
2113-
ctx->n_ranges++;
2092+
ggml_mem_ranges_reset(ctx->mem_ranges);
21142093

21152094
return true;
21162095
}
@@ -2120,92 +2099,27 @@ static bool ggml_metal_encode_mem_ranges_add_src(struct ggml_metal_encode_contex
21202099
return true;
21212100
}
21222101

2123-
struct mem_range r = {
2124-
/*.p0 =*/ (uint64_t) node->data,
2125-
/*.p1 =*/ (uint64_t) node->data + ggml_nbytes(node),
2126-
/*.pt =*/ 0,
2127-
};
2128-
2129-
if (ctx->debug > 2) {
2130-
GGML_LOG_DEBUG("%s: add src range [%lld, %lld)\n", __func__, r.p0, r.p1);
2131-
}
2132-
2133-
return ggml_metal_encode_mem_ranges_add(ctx, r);
2102+
return ggml_mem_ranges_add_src(ctx->mem_ranges, node);
21342103
}
21352104

21362105
static bool ggml_metal_encode_mem_ranges_add_dst(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
21372106
GGML_ASSERT(node);
21382107

2139-
struct mem_range r = {
2140-
/*.p0 =*/ (uint64_t) node->data,
2141-
/*.p1 =*/ (uint64_t) node->data + ggml_nbytes(node),
2142-
/*.pt =*/ 1,
2143-
};
2144-
2145-
if (ctx->debug > 2) {
2146-
GGML_LOG_DEBUG("%s: add dst range [%lld, %lld)\n", __func__, r.p0, r.p1);
2147-
}
2148-
2149-
return ggml_metal_encode_mem_ranges_add(ctx, r);
2150-
}
2151-
2152-
// return true if:
2153-
// - new src range overlaps with any existing dst range
2154-
// - new dst range overlaps with any existing range (src or dst)
2155-
static bool ggml_metal_encode_mem_ranges_check(const struct ggml_metal_encode_context * ctx, struct mem_range r) {
2156-
for (int i = 0; i < ctx->n_ranges; i++) {
2157-
if (r.pt == 0 && ctx->ranges[i].pt == 0) {
2158-
continue;
2159-
}
2160-
2161-
if (r.p0 < ctx->ranges[i].p1 && r.p1 > ctx->ranges[i].p0) {
2162-
return true;
2163-
}
2164-
}
2165-
2166-
return false;
2108+
return ggml_mem_ranges_add_dst(ctx->mem_ranges, node);
21672109
}
21682110

21692111
static bool ggml_metal_encode_mem_ranges_check_src(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
21702112
if (!node) {
21712113
return false;
21722114
}
21732115

2174-
struct mem_range r = {
2175-
/*.p0 =*/ (uint64_t) node->data,
2176-
/*.p1 =*/ (uint64_t) node->data + ggml_nbytes(node),
2177-
/*.pt =*/ 0,
2178-
};
2179-
2180-
const bool res = ggml_metal_encode_mem_ranges_check(ctx, r);
2181-
2182-
if (res) {
2183-
if (ctx->debug > 2) {
2184-
GGML_LOG_DEBUG("%s: the src range [%lld, %lld) overlaps with a previous dst range\n", __func__, r.p0, r.p1);
2185-
}
2186-
}
2187-
2188-
return res;
2116+
return ggml_mem_ranges_check_src(ctx->mem_ranges, node);
21892117
}
21902118

21912119
static bool ggml_metal_encode_mem_ranges_check_dst(const struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
21922120
GGML_ASSERT(node);
21932121

2194-
struct mem_range r = {
2195-
/*.p0 =*/ (uint64_t) node->data,
2196-
/*.p1 =*/ (uint64_t) node->data + ggml_nbytes(node),
2197-
/*.pt =*/ 1,
2198-
};
2199-
2200-
const bool res = ggml_metal_encode_mem_ranges_check(ctx, r);
2201-
2202-
if (res) {
2203-
if (ctx->debug > 2) {
2204-
GGML_LOG_DEBUG("%s: the dst range [%lld, %lld) overlaps with a previous src range\n", __func__, r.p0, r.p1);
2205-
}
2206-
}
2207-
2208-
return res;
2122+
return ggml_mem_ranges_check_dst(ctx->mem_ranges, node);
22092123
}
22102124

22112125
static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
@@ -6847,14 +6761,16 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
68476761
const bool should_capture = ctx->capture_next_compute;
68486762

68496763
struct ggml_metal_encode_context ctx_enc = {
6850-
/*.backend =*/ backend,
6851-
/*.encoder =*/ encoder,
6852-
/*.mem_pool =*/ mem_pool,
6853-
/*.n_ranges =*/ 0,
6854-
/*.ranges =*/ { 0 },
6855-
/*.debug =*/ ctx_dev->debug_graph,
6764+
/*.backend =*/ backend,
6765+
/*.encoder =*/ encoder,
6766+
/*.mem_pool =*/ mem_pool,
6767+
/*.mem_ranges =*/ NULL,
68566768
};
68576769

6770+
if (ctx_dev->use_concurrency) {
6771+
ctx_enc.mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
6772+
}
6773+
68586774
for (int idx = node_start; idx < node_end;) {
68596775
if (should_capture) {
68606776
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
@@ -6879,6 +6795,8 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
68796795

68806796
[encoder endEncoding];
68816797

6798+
ggml_mem_ranges_free(ctx_enc.mem_ranges);
6799+
68826800
if (cb_idx < 2 || ctx->abort_callback == NULL) {
68836801
[cmd_buf commit];
68846802
}

0 commit comments

Comments
 (0)