@@ -1345,6 +1345,23 @@ @implementation GGMLMetalClass
     return res;
 }
 
+static size_t ggml_metal_mul_mat_id_extra_tpe(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_MUL_MAT_ID);
+
+    const int64_t ne02 = op->src[0]->ne[2]; // n_expert
+
+    return ggml_type_size(GGML_TYPE_I32)*ne02;
+}
+
+static size_t ggml_metal_mul_mat_id_extra_ids(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_MUL_MAT_ID);
+
+    const int64_t ne02 = op->src[0]->ne[2]; // n_expert
+    const int64_t ne21 = op->src[2]->ne[1]; // n_token
+
+    return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
+}
+
 // return true if we should use the FA vector kernel for this op
 static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
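
The two new helpers factor out the sizing of the extra scratch that the MUL_MAT_ID path needs: what looks like a token-per-expert count region and an expert/token id map. As a minimal standalone sketch of the same arithmetic in plain C, with hypothetical helper names and made-up dimensions (illustrative only, not part of the change):

    #include <stdint.h>
    #include <stdio.h>

    // same sizing logic with the dimensions passed in explicitly
    static size_t extra_tpe(int64_t n_expert)                  { return sizeof(int32_t)*n_expert; }
    static size_t extra_ids(int64_t n_expert, int64_t n_token) { return sizeof(int32_t)*n_expert*n_token; }

    int main(void) {
        // e.g. 8 experts and 4 tokens -> 32 bytes of counters and 128 bytes of ids
        printf("tpe=%zu ids=%zu\n", extra_tpe(8), extra_ids(8, 4));
        return 0;
    }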
@@ -1356,6 +1373,22 @@ static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
     return (ne01 < 20) && (ne00 % 32 == 0);
 }
 
+static size_t ggml_metal_flash_attn_ext_extra_tmp(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int64_t nwg = 32;
+
+    const int64_t ne01 = op->src[0]->ne[1];
+    const int64_t ne02 = op->src[0]->ne[2];
+    const int64_t ne03 = op->src[0]->ne[3];
+    const int64_t ne20 = op->src[2]->ne[0];
+
+    // temp buffer for writing the results from each workgroup
+    // - ne20: the size of the Value head
+    // - + 2:  the S and M values for each intermediate result
+    return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+}
+
 static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
         ggml_backend_t backend, struct ggml_tensor * op,
         bool has_mask,
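
ggml_metal_flash_attn_ext_extra_tmp sizes the temp buffer in which the FA vector kernel parks each workgroup's partial result: a Value-head-sized accumulator plus the S and M softmax statistics per intermediate result. A rough standalone sketch of that arithmetic with assumed example dimensions (illustrative only):

    #include <stdint.h>
    #include <stdio.h>

    // per (query row, head, batch) triple: nwg partial results of (v_head + 2) floats
    static size_t fa_extra_tmp(int64_t n_q, int64_t n_head, int64_t n_batch,
                               int64_t v_head, int64_t nwg) {
        return sizeof(float)*n_q*n_head*n_batch*nwg*(v_head + 2);
    }

    int main(void) {
        // e.g. 1 query row, 32 heads, batch of 1, V head of 128, 32 workgroups -> 532480 bytes
        printf("%zu\n", fa_extra_tmp(1, 32, 1, 128, 32));
        return 0;
    }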
@@ -3884,9 +3917,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                     default: break;
                 }
 
-                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+                // extra buffers for intermediate id mapping
                 size_t offs_tpe = offs_dst + ggml_nbytes(dst);
-                size_t offs_ids = offs_tpe + ggml_type_size(GGML_TYPE_I32)*ne02;
+                size_t offs_ids = offs_tpe + ggml_metal_mul_mat_id_extra_tpe(dst);
 
                 {
                     ggml_metal_kargs_mul_mm_id_map0 args = {
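
For context, both extra regions live in the same Metal buffer as the destination tensor, directly after its data, and the offsets chain as computed above. A sketch of the intended layout, inferred from that arithmetic:

    // offs_dst                         offs_tpe                  offs_ids
    // |---- dst data (ggml_nbytes) ----|--- I32 * n_expert ---|--- I32 * n_expert * n_token ---|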
@@ -5303,9 +5336,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
                 GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
 
-                // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
+                // write the results from each workgroup into a temp buffer
                 const size_t offs_tmp = offs_dst + ggml_nbytes(dst);
-
                 [encoder setBuffer:id_dst offset:offs_tmp atIndex:6];
 
                 [encoder setThreadgroupMemoryLength:smem atIndex:0];
@@ -6160,28 +6192,13 @@ static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_
     switch (tensor->op) {
         case GGML_OP_MUL_MAT_ID:
             {
-                const int64_t ne02 = tensor->src[0]->ne[2]; // n_expert
-                const int64_t ne21 = tensor->src[2]->ne[1]; // n_token
-
-                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
-                res += ggml_type_size(GGML_TYPE_I32)*ne02;
-                res += ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
+                res += ggml_metal_mul_mat_id_extra_tpe(tensor);
+                res += ggml_metal_mul_mat_id_extra_ids(tensor);
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
-                    const int64_t nwg = 32;
-
-                    const int64_t ne01 = tensor->src[0]->ne[1];
-                    const int64_t ne02 = tensor->src[0]->ne[2];
-                    const int64_t ne03 = tensor->src[0]->ne[3];
-                    const int64_t ne20 = tensor->src[2]->ne[0];
-
-                    // temp buffer for writing the results from each workgroup
-                    // - ne20: the size of the Value head
-                    // - + 2: the S and M values for each intermediate result
-                    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
-                    res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+                    res += ggml_metal_flash_attn_ext_extra_tmp(tensor);
                 }
             } break;
         default:
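
res is presumably seeded with the base tensor size earlier in this function (not shown in the hunk), so the effective shared-buffer allocation for these two ops works out to, schematically:

    // MUL_MAT_ID:            nbytes(tensor) + I32*n_expert + I32*n_expert*n_token
    // FLASH_ATTN_EXT (vec):  nbytes(tensor) + F32*ne01*ne02*ne03*nwg*(ne20 + 2)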