Commit d7568a2

metal : remove mem pool usage
ggml-ci
1 parent: 261e6a2

1 file changed: +54, -90 lines

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 54 additions & 90 deletions
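Summary of the change: the encoder no longer draws temporary buffers from ggml_metal_mem_pool. Ops that need scratch storage (GGML_OP_MUL_MAT_ID and GGML_OP_FLASH_ATTN_EXT) now have their destination tensor allocated with extra room through the new get_alloc_size callback, and the encoder addresses that scratch as byte offsets past ggml_nbytes(dst) inside the same id_dst Metal buffer. A minimal C sketch of the scheme, with illustrative helper names that are not part of the patch:

    #include <stddef.h>

    // Illustrative only: the allocation backing dst also carries the op's scratch.
    static size_t dst_alloc_size(size_t nbytes_dst, size_t extra_scratch) {
        return nbytes_dst + extra_scratch; // what get_alloc_size reports for these ops
    }

    static size_t scratch_offset(size_t nbytes_dst) {
        return nbytes_dst; // scratch begins immediately after the dst tensor data
    }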
@@ -2522,7 +2522,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
             /*.nb02 =*/ nb02,
             /*.nb11 =*/ nb11,
             /*.nb21 =*/ nb21,
-
         };

         [encoder setComputePipelineState:pipeline];
@@ -3167,54 +3166,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
         const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
         const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-        // use this branch to test the ggml_metal_mem_pool functionality
-#if 0
-        // cpy to tmp buffer in MTLHeap
-
-        id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
-        if (!h_src0) {
-            GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-            return 0;
-        }
-
-        offs_src0 = 0;
-
-        ggml_metal_kargs_cpy args_cpy = {
-            /*.ne00 =*/ ne00,
-            /*.ne01 =*/ ne01,
-            /*.ne02 =*/ ne02,
-            /*.ne03 =*/ ne03,
-            /*.nb00 =*/ nb00,
-            /*.nb01 =*/ nb01,
-            /*.nb02 =*/ nb02,
-            /*.nb03 =*/ nb03,
-            /*.ne0  =*/ ne00,
-            /*.ne1  =*/ ne01,
-            /*.ne2  =*/ ne02,
-            /*.ne3  =*/ ne03,
-            /*.nb0  =*/ nb00,
-            /*.nb1  =*/ nb01,
-            /*.nb2  =*/ nb02,
-            /*.nb3  =*/ nb03,
-        };
-
-        if (src0->type == GGML_TYPE_F16) {
-            [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
-        } else {
-            [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
-        }
-        [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
-        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-        [encoder setBuffer:h_src0 offset:0 atIndex:2];
-
-        GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-        int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
-
-        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
-
-#else
         id<MTLBuffer> h_src0 = id_src0;
-#endif
+
         // softmax

         ggml_metal_kargs_soft_max args = {
@@ -4093,28 +4046,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 default: break;
             }

-            // TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
-            //       reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
-            //       so we add this extra barrier to prevent the race.
-            //       the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
-            ggml_metal_encode_concurrency_reset(ctx_enc);
-
-            // tokens per expert
-            const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
-            id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
-            if (!h_tpe) {
-                GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                return 0;
-            }
-
-            // id map
-            // [n_tokens, n_expert]
-            const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
-            id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
-            if (!h_ids) {
-                GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                return 0;
-            }
+            // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+            size_t offs_tpe = offs_dst + ggml_nbytes(dst);
+            size_t offs_ids = offs_tpe + ggml_type_size(GGML_TYPE_I32)*ne02;

             {
                 ggml_metal_kargs_mul_mm_id_map0 args = {
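Note on the hunk above: the tokens-per-expert table and the expert id map are no longer separate mem-pool buffers; they are carved out of the dst allocation, immediately after the dst data, in the extra space reserved by get_alloc_size (see [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID] below). A small C sketch of the offset arithmetic, with an illustrative helper and toy shapes that are not part of the patch:

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    // Illustrative only: mirrors the offsets computed in the hunk above.
    // nbytes_dst stands in for ggml_nbytes(dst); ne02 is the number of experts
    // (src0->ne[2]); ne21 is the number of tokens (src2->ne[1]).
    static void mul_mat_id_scratch(size_t nbytes_dst, int64_t ne02, int64_t ne21) {
        const size_t sz_i32 = sizeof(int32_t); // ggml_type_size(GGML_TYPE_I32)

        const size_t offs_tpe = nbytes_dst;                     // tokens per expert: ne02 x i32
        const size_t offs_ids = offs_tpe + sz_i32*ne02;         // id map: [ne21, ne02] x i32
        const size_t extra    = sz_i32*ne02 + sz_i32*ne21*ne02; // bytes past ggml_nbytes(dst)

        printf("offs_tpe = %zu, offs_ids = %zu, extra = %zu\n", offs_tpe, offs_ids, extra);
    }

    int main(void) {
        mul_mat_id_scratch(/*nbytes_dst =*/ 4096, /*ne02 =*/ 8, /*ne21 =*/ 32); // toy values
        return 0;
    }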
@@ -4152,8 +4086,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
-                [encoder setBuffer: h_tpe offset:0 atIndex:2];
-                [encoder setBuffer: h_ids offset:0 atIndex:3];
+                [encoder setBuffer:id_dst offset:offs_tpe atIndex:2];
+                [encoder setBuffer:id_dst offset:offs_ids atIndex:3];
                 [encoder setThreadgroupMemoryLength:smem atIndex:0];

                 [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
@@ -4215,8 +4149,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 [encoder setBytes:&args length:sizeof(args) atIndex:0];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
-                [encoder setBuffer: h_tpe offset:0 atIndex:3];
-                [encoder setBuffer: h_ids offset:0 atIndex:4];
+                [encoder setBuffer:id_dst offset:offs_tpe atIndex:3];
+                [encoder setBuffer:id_dst offset:offs_ids atIndex:4];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:5];

                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
@@ -5307,6 +5241,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
             GGML_ASSERT(ne01 < 65536);

             // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
+            // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
             if (ne01 >= 20 || (ne00 % 32 != 0)) {
                 // half8x8 kernel
                 const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
@@ -5532,30 +5467,20 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
                 GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));

-                // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
-                // still, we assume that concurrent FA won't happen before we do the refactor
-                //ggml_metal_encode_concurrency_reset(ctx_enc);
-
                 const int32_t nrows = ne1*ne2*ne3;

-                // temp buffer for writing the results from each workgroup
-                // - ne20: the size of the head vector
-                // - + 2: the S and M values for each intermediate result
-                const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
-                id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
-                if (!h_tmp) {
-                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
-                    return 0;
-                }
+                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+                const size_t offs_tmp = offs_dst + ggml_nbytes(dst);

                 //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
                 //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);

-                [encoder setBuffer:h_tmp offset:0 atIndex:6];
+                [encoder setBuffer:id_dst offset:offs_tmp atIndex:6];

                 [encoder setThreadgroupMemoryLength:smem atIndex:0];
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];

+                // sync the 2 kernels
                 ggml_metal_encode_concurrency_reset(ctx_enc);

                 // reduce the results from the workgroups
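Note on the hunk above: the partial results that the vec flash-attention kernel writes per workgroup keep the size documented by the removed comment (ne20 floats for the head vector plus 2 for the S and M values, per row and per workgroup), but they now live right after dst instead of coming from the mem pool; the concurrency reset above still separates the attention kernel from the reduce kernel that consumes these partials. A rough C sketch of the size arithmetic, with illustrative names and toy shapes that are not part of the patch:

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    // Illustrative only: bytes of per-workgroup partial results, as documented
    // by the removed comment: nrows * nwg * (ne20 + 2) floats.
    static size_t fa_tmp_bytes(int64_t nrows, int64_t nwg, int64_t ne20) {
        return sizeof(float)*(size_t)(nrows*nwg*(ne20 + 2));
    }

    int main(void) {
        const int64_t nrows = 4*32;  // ne1*ne2*ne3 for a toy case
        const int64_t nwg   = 32;    // matches the constant used in get_alloc_size below
        const int64_t ne20  = 128;   // head size

        const size_t nbytes_dst = 4096;       // stands in for ggml_nbytes(dst)
        const size_t offs_tmp   = nbytes_dst; // partials start right after the dst data

        printf("offs_tmp = %zu, tmp bytes = %zu\n", offs_tmp, fa_tmp_bytes(nrows, nwg, ne20));
        return 0;
    }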
@@ -5568,7 +5493,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in

                 [encoder setComputePipelineState:pipeline0];
                 [encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
-                [encoder setBuffer:h_tmp offset:0 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_tmp atIndex:1];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

                 //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
@@ -6401,6 +6326,45 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
     return max_size;
 }

+static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    size_t res = ggml_nbytes(tensor);
+
+    switch (tensor->op) {
+        case GGML_OP_MUL_MAT_ID:
+            {
+                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+                const int64_t ne02 = tensor->src[0]->ne[2];
+                const int64_t ne21 = tensor->src[2]->ne[1];
+
+                res += ggml_type_size(GGML_TYPE_I32)*ne02;
+                res += ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
+            } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
+                const int64_t nwg = 32;
+
+                const int64_t ne01 = tensor->src[0]->ne[1];
+                const int64_t ne02 = tensor->src[0]->ne[2];
+                const int64_t ne03 = tensor->src[0]->ne[3];
+                const int64_t ne20 = tensor->src[2]->ne[0];
+
+                if (ne01 < 20) {
+                    // temp buffer for writing the results from each workgroup
+                    // - ne20: the size of the head vector
+                    // - + 2: the S and M values for each intermediate result
+                    res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+                }
+            } break;
+        default:
+            break;
+    }
+
+    return res;
+
+    GGML_UNUSED(buft);
+}
+
 static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
     return false;
@@ -6414,7 +6378,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
         /* .alloc_buffer   = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
         /* .get_alignment  = */ ggml_backend_metal_buffer_type_shared_get_alignment,
         /* .get_max_size   = */ ggml_backend_metal_buffer_type_shared_get_max_size,
-        /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+        /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
         /* .is_host        = */ ggml_backend_metal_buffer_type_shared_is_host,
     },
     /* .device = */ &g_ggml_backend_metal_device,
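Note on the last two hunks: the extra bytes are only reserved because the shared buffer type now exposes get_alloc_size instead of leaving it NULL (which falls back to ggml_nbytes), so the backend allocator sizes MUL_MAT_ID and FLASH_ATTN_EXT destinations with the scratch included. A hedged usage sketch of how callers query this through the public ggml-backend helper; the exact const-qualification of the tensor argument may differ between ggml revisions:

    #include "ggml.h"
    #include "ggml-backend.h"

    // Sketch only: per-tensor allocation size as seen through the buffer type.
    // For most ops this equals ggml_nbytes(tensor); for GGML_OP_MUL_MAT_ID and
    // GGML_OP_FLASH_ATTN_EXT on this buffer type it is larger by the scratch
    // amounts computed in ggml_backend_metal_buffer_type_shared_get_alloc_size.
    size_t metal_shared_tensor_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
        return ggml_backend_buft_get_alloc_size(buft, tensor);
    }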
