Skip to content

Commit b71d5f0

Browse files
committed
metal : add memory pool for temp allocs (wip) [no ci]
1 parent d3bd719 commit b71d5f0

File tree

1 file changed

+69
-14
lines changed

1 file changed

+69
-14
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
// note: assumes single GPU device - the default one
4545
// TODO: support multiple GPU devices
4646
static struct ggml_backend_metal_device_context {
47-
id<MTLDevice> mtl_device;
48-
int mtl_device_ref_count;
47+
id<MTLDevice> mtl_device;
48+
int mtl_device_ref_count;
4949
id<MTLLibrary> mtl_library;
5050

5151
bool has_simdgroup_reduction;
@@ -470,6 +470,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
470470

471471
struct ggml_backend_metal_context {
472472
id<MTLCommandQueue> queue;
473+
id<MTLHeap> heap;
473474

474475
dispatch_queue_t d_queue;
475476

@@ -693,6 +694,19 @@ @implementation GGMLMetalClass
693694

694695
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
695696

697+
// allocate tmp heap with fixed size for testing
698+
// TODO: figure out how to dynamically resize it
699+
{
700+
MTLHeapDescriptor *heapDescriptor = [[MTLHeapDescriptor alloc] init];
701+
heapDescriptor.storageMode = MTLStorageModePrivate;
702+
heapDescriptor.cpuCacheMode = MTLCPUCacheModeDefaultCache;
703+
heapDescriptor.size = 32*1024*1024;
704+
705+
ctx->heap = [device newHeapWithDescriptor:heapDescriptor];
706+
707+
[heapDescriptor release];
708+
}
709+
696710
// load library
697711
if (ctx_dev->mtl_library == nil) {
698712
ctx_dev->mtl_library = ggml_metal_load_library(device, ctx_dev->use_bfloat);
@@ -1136,6 +1150,7 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
11361150
Block_release(ctx->encode_async);
11371151

11381152
[ctx->queue release];
1153+
[ctx->heap release];
11391154

11401155
dispatch_release(ctx->d_queue);
11411156

@@ -1438,7 +1453,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
14381453
static void ggml_metal_encode_node(
14391454
ggml_backend_t backend,
14401455
int idx,
1441-
id<MTLComputeCommandEncoder> encoder) {
1456+
id<MTLComputeCommandEncoder> encoder,
1457+
id<MTLHeap> heap) {
14421458
struct ggml_backend_metal_context * ctx = backend->context;
14431459
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
14441460

@@ -2110,26 +2126,65 @@ static void ggml_metal_encode_node(
21102126
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
21112127
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
21122128

2113-
ggml_metal_kargs_soft_max args = {
2129+
// cpy to tmp buffer in MTLHeap
2130+
2131+
ggml_metal_kargs_cpy args_cpy = {
21142132
/*.ne00 =*/ ne00,
21152133
/*.ne01 =*/ ne01,
21162134
/*.ne02 =*/ ne02,
2117-
/*.scale =*/ scale,
2118-
/*.max_bias =*/ max_bias,
2119-
/*.m0 =*/ m0,
2120-
/*.m1 =*/ m1,
2135+
/*.ne03 =*/ ne03,
2136+
/*.nb00 =*/ nb00,
2137+
/*.nb01 =*/ nb01,
2138+
/*.nb02 =*/ nb02,
2139+
/*.nb03 =*/ nb03,
2140+
/*.ne0 =*/ ne00,
2141+
/*.ne1 =*/ ne01,
2142+
/*.ne2 =*/ ne02,
2143+
/*.ne3 =*/ ne03,
2144+
/*.nb0 =*/ nb00,
2145+
/*.nb1 =*/ nb01,
2146+
/*.nb2 =*/ nb02,
2147+
/*.nb3 =*/ nb03,
2148+
};
2149+
2150+
id<MTLBuffer> id_src0h = [heap newBufferWithLength:ggml_nbytes(src0) options:MTLResourceStorageModePrivate];
2151+
2152+
if (src0->type == GGML_TYPE_F16) {
2153+
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
2154+
} else {
2155+
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
2156+
}
2157+
[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
2158+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2159+
[encoder setBuffer:id_src0h offset:0 atIndex:2];
2160+
2161+
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
2162+
int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
2163+
2164+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
2165+
2166+
// softmax
2167+
2168+
ggml_metal_kargs_soft_max args = {
2169+
/*.ne00 =*/ ne00,
2170+
/*.ne01 =*/ ne01,
2171+
/*.ne02 =*/ ne02,
2172+
/*.scale =*/ scale,
2173+
/*.max_bias =*/ max_bias,
2174+
/*.m0 =*/ m0,
2175+
/*.m1 =*/ m1,
21212176
/*.n_head_log2 =*/ n_head_log2,
21222177
};
21232178

21242179
[encoder setComputePipelineState:pipeline];
2125-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2180+
[encoder setBuffer:id_src0h offset:0 atIndex:0];
21262181
if (id_src1) {
2127-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2182+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
21282183
} else {
2129-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2184+
[encoder setBuffer:id_src0h offset:0 atIndex:1];
21302185
}
2131-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2132-
[encoder setBytes:&args length:sizeof(args) atIndex:3];
2186+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2187+
[encoder setBytes:&args length:sizeof(args) atIndex:3];
21332188

21342189
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
21352190

@@ -4991,7 +5046,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
49915046
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
49925047
}
49935048

4994-
ggml_metal_encode_node(backend, idx, encoder);
5049+
ggml_metal_encode_node(backend, idx, encoder, ctx->heap);
49955050

49965051
if (should_capture) {
49975052
[encoder popDebugGroup];

0 commit comments

Comments
 (0)