Commit 29f5903

metal : remove mem pool usage
ggml-ci
1 parent 55758b0 commit 29f5903

1 file changed: +54 -90

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 54 additions & 90 deletions
@@ -2521,7 +2521,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     /*.nb02 =*/ nb02,
     /*.nb11 =*/ nb11,
     /*.nb21 =*/ nb21,
-
 };

 [encoder setComputePipelineState:pipeline];
@@ -3166,54 +3165,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-// use this branch to test the ggml_metal_mem_pool functionality
-#if 0
-// cpy to tmp buffer in MTLHeap
-
-id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
-if (!h_src0) {
-    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-    return 0;
-}
-
-offs_src0 = 0;
-
-ggml_metal_kargs_cpy args_cpy = {
-    /*.ne00 =*/ ne00,
-    /*.ne01 =*/ ne01,
-    /*.ne02 =*/ ne02,
-    /*.ne03 =*/ ne03,
-    /*.nb00 =*/ nb00,
-    /*.nb01 =*/ nb01,
-    /*.nb02 =*/ nb02,
-    /*.nb03 =*/ nb03,
-    /*.ne0 =*/ ne00,
-    /*.ne1 =*/ ne01,
-    /*.ne2 =*/ ne02,
-    /*.ne3 =*/ ne03,
-    /*.nb0 =*/ nb00,
-    /*.nb1 =*/ nb01,
-    /*.nb2 =*/ nb02,
-    /*.nb3 =*/ nb03,
-};
-
-if (src0->type == GGML_TYPE_F16) {
-    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
-} else {
-    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
-}
-[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
-[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-[encoder setBuffer:h_src0 offset:0 atIndex:2];
-
-GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
-
-[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
-
-#else
 id<MTLBuffer> h_src0 = id_src0;
-#endif
+
 // softmax

 ggml_metal_kargs_soft_max args = {
@@ -4092,28 +4045,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     default: break;
 }

-// TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
-// reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
-// so we add this extra barrier to prevent the race.
-// the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
-ggml_metal_encode_concurrency_reset(ctx_enc);
-
-// tokens per expert
-const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
-id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
-if (!h_tpe) {
-    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-    return 0;
-}
-
-// id map
-// [n_tokens, n_expert]
-const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
-id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
-if (!h_ids) {
-    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-    return 0;
-}
+// [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+size_t offs_tpe = offs_dst + ggml_nbytes(dst);
+size_t offs_ids = offs_tpe + ggml_type_size(GGML_TYPE_I32)*ne02;

 {
     ggml_metal_kargs_mul_mm_id_map0 args = {
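The two mem-pool allocations for MUL_MAT_ID become fixed offsets into the dst buffer itself: the tokens-per-expert counters start right after the dst tensor data, and the [n_tokens, n_expert] id map follows them. A minimal C sketch of that layout arithmetic, using made-up values for the dst size, ne02 and ne21 (the real values come from the tensor shapes in the surrounding kernel code):

    // sketch only: layout of the MUL_MAT_ID scratch regions appended to the dst data
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const size_t  nbytes_dst = 1024*1024; // hypothetical ggml_nbytes(dst)
        const int64_t ne02       = 8;         // hypothetical number of experts
        const int64_t ne21       = 32;        // hypothetical number of tokens

        const size_t offs_dst = 0;                                    // dst data starts at its tensor offset
        const size_t offs_tpe = offs_dst + nbytes_dst;                // tokens-per-expert counters
        const size_t offs_ids = offs_tpe + sizeof(int32_t)*ne02;      // id map, [n_tokens, n_expert]
        const size_t total    = offs_ids + sizeof(int32_t)*ne21*ne02; // bytes get_alloc_size must cover

        printf("offs_tpe = %zu, offs_ids = %zu, total = %zu\n", offs_tpe, offs_ids, total);
        return 0;
    }

The same arithmetic reappears in ggml_backend_metal_buffer_type_shared_get_alloc_size further down, so the allocator reserves the whole total for the node's dst tensor and no per-encode pool allocation is needed.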
@@ -4151,8 +4085,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 [encoder setComputePipelineState:pipeline];
 [encoder setBytes:&args length:sizeof(args) atIndex:0];
 [encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
-[encoder setBuffer: h_tpe offset:0 atIndex:2];
-[encoder setBuffer: h_ids offset:0 atIndex:3];
+[encoder setBuffer:id_dst offset:offs_tpe atIndex:2];
+[encoder setBuffer:id_dst offset:offs_ids atIndex:3];
 [encoder setThreadgroupMemoryLength:smem atIndex:0];

 [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
@@ -4214,8 +4148,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 [encoder setBytes:&args length:sizeof(args) atIndex:0];
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
-[encoder setBuffer: h_tpe offset:0 atIndex:3];
-[encoder setBuffer: h_ids offset:0 atIndex:4];
+[encoder setBuffer:id_dst offset:offs_tpe atIndex:3];
+[encoder setBuffer:id_dst offset:offs_ids atIndex:4];
 [encoder setBuffer:id_dst offset:offs_dst atIndex:5];

 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
@@ -5306,6 +5240,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 GGML_ASSERT(ne01 < 65536);

 // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
+// [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
 if (ne01 >= 20 || (ne00 % 32 != 0)) {
     // half8x8 kernel
     const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
@@ -5531,30 +5466,20 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
 GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));

-// using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
-// still, we assume that concurrent FA won't happen before we do the refactor
-//ggml_metal_encode_concurrency_reset(ctx_enc);
-
 const int32_t nrows = ne1*ne2*ne3;

-// temp buffer for writing the results from each workgroup
-// - ne20: the size of the head vector
-// - + 2: the S and M values for each intermediate result
-const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
-id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
-if (!h_tmp) {
-    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
-    return 0;
-}
+// [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+const size_t offs_tmp = offs_dst + ggml_nbytes(dst);

 //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
 //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);

-[encoder setBuffer:h_tmp offset:0 atIndex:6];
+[encoder setBuffer:id_dst offset:offs_tmp atIndex:6];

 [encoder setThreadgroupMemoryLength:smem atIndex:0];
 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];

+// sync the 2 kernels
 ggml_metal_encode_concurrency_reset(ctx_enc);

 // reduce the results from the workgroups
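For the split flash-attention path the per-workgroup partial results likewise move out of the mem pool and into the tail of the dst buffer at offs_tmp. Each of the nrows*nwg partials holds ne20 head-vector floats plus the S and M reduction scalars, which matches the extra space added for GGML_OP_FLASH_ATTN_EXT in get_alloc_size below. A small C sketch of that size computation with hypothetical dimensions:

    // sketch only: size of the flash-attention partial-results area appended to dst
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne1 = 4, ne2 = 32, ne3 = 1; // hypothetical output dims
        const int64_t nwg  = 32;                  // workgroups, each producing one partial result per row
        const int64_t ne20 = 128;                 // head size

        const int64_t nrows = ne1*ne2*ne3;
        const size_t  s_tmp = sizeof(float)*(size_t)(nrows*nwg*(ne20 + 2)); // ne20 values + S and M per partial

        printf("temp area = %zu bytes\n", s_tmp);
        return 0;
    }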
@@ -5567,7 +5492,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in

 [encoder setComputePipelineState:pipeline0];
 [encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
-[encoder setBuffer:h_tmp offset:0 atIndex:1];
+[encoder setBuffer:id_dst offset:offs_tmp atIndex:1];
 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];

 //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
@@ -6400,6 +6325,45 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
     return max_size;
 }

+static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    size_t res = ggml_nbytes(tensor);
+
+    switch (tensor->op) {
+        case GGML_OP_MUL_MAT_ID:
+            {
+                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+                const int64_t ne02 = tensor->src[0]->ne[2];
+                const int64_t ne21 = tensor->src[2]->ne[1];
+
+                res += ggml_type_size(GGML_TYPE_I32)*ne02;
+                res += ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
+            } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
+                const int64_t nwg = 32;
+
+                const int64_t ne01 = tensor->src[0]->ne[1];
+                const int64_t ne02 = tensor->src[0]->ne[2];
+                const int64_t ne03 = tensor->src[0]->ne[3];
+                const int64_t ne20 = tensor->src[2]->ne[0];
+
+                if (ne01 < 20) {
+                    // temp buffer for writing the results from each workgroup
+                    // - ne20: the size of the head vector
+                    // - + 2: the S and M values for each intermediate result
+                    res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+                }
+            } break;
+        default:
+            break;
+    }
+
+    return res;
+
+    GGML_UNUSED(buft);
+}
+
 static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
     return false;
@@ -6413,7 +6377,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
 /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
 /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment,
 /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size,
-/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
+/* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
 /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host,
 },
 /* .device = */ &g_ggml_backend_metal_device,
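Installing .get_alloc_size is what makes the extra scratch bytes visible to the generic allocator: the ggml buffer-type API (ggml_backend_buft_get_alloc_size) asks the callback when one is set and otherwise falls back to ggml_nbytes(), which is the default the removed line relied on. A rough sketch of that dispatch, under the assumption of the internal iface layout from ggml-backend-impl.h and not copied verbatim from ggml-backend:

    // rough sketch (hypothetical helper name); requires "ggml-backend-impl.h" for buft->iface
    // with the callback installed, sizing a MUL_MAT_ID or FLASH_ATTN_EXT dst tensor on this
    // buffer type now yields ggml_nbytes() plus the scratch bytes computed above
    size_t buft_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
        if (buft->iface.get_alloc_size != NULL) {
            return buft->iface.get_alloc_size(buft, tensor); // e.g. ..._shared_get_alloc_size
        }
        return ggml_nbytes(tensor); // old behaviour: plain tensor size
    }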
