Skip to content

Commit fcc113f

Browse files
committed
metal : implement .get_alloc_size for the rest of the buffer types
ggml-ci
1 parent c3827e6 commit fcc113f

File tree

1 file changed

+38
-24
lines changed

1 file changed

+38
-24
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,7 @@ @implementation GGMLMetalClass
13451345
return res;
13461346
}
13471347

1348+
// tokens per expert
13481349
static size_t ggml_metal_mul_mat_id_extra_tpe(const struct ggml_tensor * op) {
13491350
assert(op->op == GGML_OP_MUL_MAT_ID);
13501351

@@ -1353,6 +1354,7 @@ static size_t ggml_metal_mul_mat_id_extra_tpe(const struct ggml_tensor * op) {
13531354
return ggml_type_size(GGML_TYPE_I32)*ne02;
13541355
}
13551356

1357+
// id map [n_tokens, n_expert]
13561358
static size_t ggml_metal_mul_mat_id_extra_ids(const struct ggml_tensor * op) {
13571359
assert(op->op == GGML_OP_MUL_MAT_ID);
13581360

@@ -6161,6 +6163,31 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
61616163
return ggml_backend_buffer_init(buft, buf_i, ctx, size);
61626164
}
61636165

6166+
static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
6167+
size_t res = ggml_nbytes(tensor);
6168+
6169+
// some operations require additional memory for fleeting data:
6170+
switch (tensor->op) {
6171+
case GGML_OP_MUL_MAT_ID:
6172+
{
6173+
res += ggml_metal_mul_mat_id_extra_tpe(tensor);
6174+
res += ggml_metal_mul_mat_id_extra_ids(tensor);
6175+
} break;
6176+
case GGML_OP_FLASH_ATTN_EXT:
6177+
{
6178+
if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
6179+
res += ggml_metal_flash_attn_ext_extra_tmp(tensor);
6180+
}
6181+
} break;
6182+
default:
6183+
break;
6184+
}
6185+
6186+
return res;
6187+
6188+
GGML_UNUSED(buft);
6189+
}
6190+
61646191
// default (shared) buffer type
61656192

61666193
static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
@@ -6186,28 +6213,7 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
61866213
}
61876214

61886215
static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
6189-
size_t res = ggml_nbytes(tensor);
6190-
6191-
// some operations require additional memory for fleeting data:
6192-
switch (tensor->op) {
6193-
case GGML_OP_MUL_MAT_ID:
6194-
{
6195-
res += ggml_metal_mul_mat_id_extra_tpe(tensor);
6196-
res += ggml_metal_mul_mat_id_extra_ids(tensor);
6197-
} break;
6198-
case GGML_OP_FLASH_ATTN_EXT:
6199-
{
6200-
if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
6201-
res += ggml_metal_flash_attn_ext_extra_tmp(tensor);
6202-
}
6203-
} break;
6204-
default:
6205-
break;
6206-
}
6207-
6208-
return res;
6209-
6210-
GGML_UNUSED(buft);
6216+
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
62116217
}
62126218

62136219
static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
@@ -6257,6 +6263,10 @@ static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_b
62576263
return max_size;
62586264
}
62596265

6266+
static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
6267+
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
6268+
}
6269+
62606270
static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {
62616271
return false;
62626272

@@ -6270,7 +6280,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
62706280
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
62716281
/* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment,
62726282
/* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size,
6273-
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
6283+
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
62746284
/* .is_host = */ ggml_backend_metal_buffer_type_private_is_host,
62756285
},
62766286
/* .device = */ &g_ggml_backend_metal_device,
@@ -6305,6 +6315,10 @@ static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_bu
63056315
return max_size;
63066316
}
63076317

6318+
static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
6319+
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
6320+
}
6321+
63086322
static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {
63096323
return false;
63106324

@@ -6320,7 +6334,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
63206334
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
63216335
/* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
63226336
/* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
6323-
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
6337+
/* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
63246338
/* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host,
63256339
},
63266340
/* .device = */ &g_ggml_backend_metal_device,

0 commit comments

Comments (0)