Skip to content

Commit 14d8c54

Browse files
committed
cont : add functions for the extra tensor sizes
1 parent 389e7e4 commit 14d8c54

File tree

1 file changed

+39
-22
lines changed

1 file changed

+39
-22
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,23 @@ @implementation GGMLMetalClass
13451345
return res;
13461346
}
13471347

1348+
// size in bytes of the extra "tpe" buffer needed by a GGML_OP_MUL_MAT_ID op:
// one I32 entry per expert (src[0]->ne[2])
static size_t ggml_metal_mul_mat_id_extra_tpe(const struct ggml_tensor * op) {
    assert(op->op == GGML_OP_MUL_MAT_ID);

    const int64_t n_expert = op->src[0]->ne[2];

    return n_expert*ggml_type_size(GGML_TYPE_I32);
}
1355+
1356+
// size in bytes of the extra "ids" buffer needed by a GGML_OP_MUL_MAT_ID op:
// one I32 entry per (expert, token) pair (src[0]->ne[2] x src[2]->ne[1])
static size_t ggml_metal_mul_mat_id_extra_ids(const struct ggml_tensor * op) {
    assert(op->op == GGML_OP_MUL_MAT_ID);

    const int64_t n_expert = op->src[0]->ne[2];
    const int64_t n_token  = op->src[2]->ne[1];

    return n_expert*n_token*ggml_type_size(GGML_TYPE_I32);
}
1364+
13481365
// return true if we should use the FA vector kernel for this op
13491366
static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
13501367
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1356,6 +1373,22 @@ static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
13561373
return (ne01 < 20) && (ne00 % 32 == 0);
13571374
}
13581375

1376+
// size in bytes of the temporary F32 buffer needed by a GGML_OP_FLASH_ATTN_EXT op
// when the vector kernel is used: each of the nwg workgroups writes a partial
// result per row of src[0]
static size_t ggml_metal_flash_attn_ext_extra_tmp(const struct ggml_tensor * op) {
    assert(op->op == GGML_OP_FLASH_ATTN_EXT);

    const int64_t nwg = 32; // number of workgroups

    // total number of rows across all heads/batches of the Query tensor
    const int64_t nrows = op->src[0]->ne[1]*op->src[0]->ne[2]*op->src[0]->ne[3];

    const int64_t ne20 = op->src[2]->ne[0]; // size of the Value head

    // per intermediate result: ne20 values plus the S and M accumulators (the "+ 2")
    return ggml_type_size(GGML_TYPE_F32)*nrows*nwg*(ne20 + 2);
}
1391+
13591392
static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
13601393
ggml_backend_t backend, struct ggml_tensor * op,
13611394
bool has_mask,
@@ -3884,9 +3917,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
38843917
default: break;
38853918
}
38863919

3887-
// [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
3920+
// extra buffers for intermediate id mapping
38883921
size_t offs_tpe = offs_dst + ggml_nbytes(dst);
3889-
size_t offs_ids = offs_tpe + ggml_type_size(GGML_TYPE_I32)*ne02;
3922+
size_t offs_ids = offs_tpe + ggml_metal_mul_mat_id_extra_tpe(dst);
38903923

38913924
{
38923925
ggml_metal_kargs_mul_mm_id_map0 args = {
@@ -5303,9 +5336,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
53035336
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
53045337
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
53055338

5306-
// [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
5339+
// write the results from each workgroup into a temp buffer
53075340
const size_t offs_tmp = offs_dst + ggml_nbytes(dst);
5308-
53095341
[encoder setBuffer:id_dst offset:offs_tmp atIndex:6];
53105342

53115343
[encoder setThreadgroupMemoryLength:smem atIndex:0];
@@ -6160,28 +6192,13 @@ static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_
61606192
switch (tensor->op) {
61616193
case GGML_OP_MUL_MAT_ID:
61626194
{
6163-
const int64_t ne02 = tensor->src[0]->ne[2]; // n_expert
6164-
const int64_t ne21 = tensor->src[2]->ne[1]; // n_token
6165-
6166-
// [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
6167-
res += ggml_type_size(GGML_TYPE_I32)*ne02;
6168-
res += ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
6195+
res += ggml_metal_mul_mat_id_extra_tpe(tensor);
6196+
res += ggml_metal_mul_mat_id_extra_ids(tensor);
61696197
} break;
61706198
case GGML_OP_FLASH_ATTN_EXT:
61716199
{
61726200
if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
6173-
const int64_t nwg = 32;
6174-
6175-
const int64_t ne01 = tensor->src[0]->ne[1];
6176-
const int64_t ne02 = tensor->src[0]->ne[2];
6177-
const int64_t ne03 = tensor->src[0]->ne[3];
6178-
const int64_t ne20 = tensor->src[2]->ne[0];
6179-
6180-
// temp buffer for writing the results from each workgroup
6181-
// - ne20: the size of the Value head
6182-
// - + 2: the S and M values for each intermediate result
6183-
// [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
6184-
res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
6201+
res += ggml_metal_flash_attn_ext_extra_tmp(tensor);
61856202
}
61866203
} break;
61876204
default:

0 commit comments

Comments
 (0)