@@ -1345,6 +1345,7 @@ @implementation GGMLMetalClass
13451345 return res;
13461346}
13471347
1348+ // tokens per expert
13481349static size_t ggml_metal_mul_mat_id_extra_tpe (const struct ggml_tensor * op) {
13491350 assert (op->op == GGML_OP_MUL_MAT_ID);
13501351
@@ -1353,6 +1354,7 @@ static size_t ggml_metal_mul_mat_id_extra_tpe(const struct ggml_tensor * op) {
13531354 return ggml_type_size (GGML_TYPE_I32)*ne02;
13541355}
13551356
1357+ // id map [n_tokens, n_expert]
13561358static size_t ggml_metal_mul_mat_id_extra_ids (const struct ggml_tensor * op) {
13571359 assert (op->op == GGML_OP_MUL_MAT_ID);
13581360
@@ -6161,6 +6163,31 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
61616163 return ggml_backend_buffer_init (buft, buf_i, ctx, size);
61626164}
61636165
6166+ static size_t ggml_backend_metal_buffer_type_get_alloc_size (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { // common alloc-size helper; called by the shared, private and mapped get_alloc_size callbacks
6167+ size_t res = ggml_nbytes (tensor); // start from the tensor's own byte size
6168+
6169+ // some operations require additional memory for fleeting data:
6170+ switch (tensor->op ) {
6171+ case GGML_OP_MUL_MAT_ID:
6172+ {
6173+ res += ggml_metal_mul_mat_id_extra_tpe (tensor); // extra bytes for the tokens-per-expert data
6174+ res += ggml_metal_mul_mat_id_extra_ids (tensor); // extra bytes for the id map
6175+ } break ;
6176+ case GGML_OP_FLASH_ATTN_EXT:
6177+ {
6178+ if (ggml_metal_flash_attn_ext_use_vec (tensor)) { // extra tmp space is added only when this check returns true
6179+ res += ggml_metal_flash_attn_ext_extra_tmp (tensor);
6180+ }
6181+ } break ;
6182+ default : // every other op: no extra bytes beyond ggml_nbytes
6183+ break ;
6184+ }
6185+
6186+ return res; // total bytes to reserve for this tensor
6187+
6188+ GGML_UNUSED (buft); // sits after the return, so never executed; presumably only marks buft as intentionally unused
6189+ }
6190+
61646191// default (shared) buffer type
61656192
61666193static const char * ggml_backend_metal_buffer_type_shared_get_name (ggml_backend_buffer_type_t buft) {
@@ -6186,28 +6213,7 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
61866213}
61876214
61886215static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
6189- size_t res = ggml_nbytes (tensor);
6190-
6191- // some operations require additional memory for fleeting data:
6192- switch (tensor->op ) {
6193- case GGML_OP_MUL_MAT_ID:
6194- {
6195- res += ggml_metal_mul_mat_id_extra_tpe (tensor);
6196- res += ggml_metal_mul_mat_id_extra_ids (tensor);
6197- } break ;
6198- case GGML_OP_FLASH_ATTN_EXT:
6199- {
6200- if (ggml_metal_flash_attn_ext_use_vec (tensor)) {
6201- res += ggml_metal_flash_attn_ext_extra_tmp (tensor);
6202- }
6203- } break ;
6204- default :
6205- break ;
6206- }
6207-
6208- return res;
6209-
6210- GGML_UNUSED (buft);
6216+ return ggml_backend_metal_buffer_type_get_alloc_size (buft, tensor); // shared buffer type: forward to the common alloc-size helper, which now holds the per-op logic removed from this function
62116217}
62126218
62136219static bool ggml_backend_metal_buffer_type_shared_is_host (ggml_backend_buffer_type_t buft) {
@@ -6257,6 +6263,10 @@ static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_b
62576263 return max_size;
62586264}
62596265
6266+ static size_t ggml_backend_metal_buffer_type_private_get_alloc_size (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
6267+ return ggml_backend_metal_buffer_type_get_alloc_size (buft, tensor); // private buffer type: forward to the common alloc-size helper instead of the previous NULL (ggml_nbytes) default
6268+ }
6269+
62606270static bool ggml_backend_metal_buffer_type_private_is_host (ggml_backend_buffer_type_t buft) {
62616271 return false ;
62626272
@@ -6270,7 +6280,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
62706280 /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
62716281 /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment,
62726282 /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size,
6273- /* .get_alloc_size = */ NULL , // defaults to ggml_nbytes
6283+ /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
62746284 /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host,
62756285 },
62766286 /* .device = */ &g_ggml_backend_metal_device,
@@ -6305,6 +6315,10 @@ static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_bu
63056315 return max_size;
63066316}
63076317
6318+ static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
6319+ return ggml_backend_metal_buffer_type_get_alloc_size (buft, tensor); // mapped buffer type: forward to the common alloc-size helper instead of the previous NULL (ggml_nbytes) default
6320+ }
6321+
63086322static bool ggml_backend_metal_buffer_type_mapped_is_host (ggml_backend_buffer_type_t buft) {
63096323 return false ;
63106324
@@ -6320,7 +6334,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
63206334 /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
63216335 /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
63226336 /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
6323- /* .get_alloc_size = */ NULL , // defaults to ggml_nbytes
6337+ /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
63246338 /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host,
63256339 },
63266340 /* .device = */ &g_ggml_backend_metal_device,
0 commit comments