@@ -1345,6 +1345,17 @@ @implementation GGMLMetalClass
     return res;
 }
 
+// return true if we should use the FA vector kernel for this op
+static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int64_t ne00 = op->src[0]->ne[0]; // head size
+    const int64_t ne01 = op->src[0]->ne[1]; // batch size
+
+    // use vec kernel if the batch size is small and if the head size is supported
+    return (ne01 < 20) && (ne00 % 32 == 0);
+}
+
 static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
         ggml_backend_t backend, struct ggml_tensor * op,
         bool has_mask,
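The new helper gives both call sites below a single source of truth for the vec-kernel heuristic, which was previously duplicated between the encoder and the alloc-size hook. For illustration, a standalone sketch of what the two conditions select (the thresholds are copied from the diff; the `use_vec` wrapper and the sample shapes are hypothetical):

```c
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

// same heuristic as ggml_metal_flash_attn_ext_use_vec:
// a small query batch and a head size that is a multiple of 32
static bool use_vec(int64_t head_size, int64_t batch_size) {
    return (batch_size < 20) && (head_size % 32 == 0);
}

int main(void) {
    printf("%d\n", use_vec(128,   1)); // 1: single-token decode -> vec kernel
    printf("%d\n", use_vec(128, 512)); // 0: large batch (prompt) -> half8x8 kernel
    printf("%d\n", use_vec( 80,   1)); // 0: unsupported head size -> half8x8 kernel
    return 0;
}
```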
@@ -5067,9 +5078,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     GGML_ASSERT(ne01 < 65536);
 
-    // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
-    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
-    if (ne01 >= 20 || (ne00 % 32 != 0)) {
+    if (!ggml_metal_flash_attn_ext_use_vec(dst)) {
         // half8x8 kernel
         const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
         const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
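The deleted inline condition is the exact negation of the new predicate, so kernel selection is unchanged by the refactor. A tiny self-contained check of that equivalence (hypothetical test code, not part of the commit):

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

int main(void) {
    for (int64_t ne00 = 1; ne00 <= 256; ++ne00) {
        for (int64_t ne01 = 1; ne01 <= 64; ++ne01) {
            const bool old_non_vec = ne01 >= 20 || (ne00 % 32 != 0);  // removed check
            const bool new_use_vec = (ne01 < 20) && (ne00 % 32 == 0); // new helper
            assert(old_non_vec == !new_use_vec); // De Morgan
        }
    }
    return 0;
}
```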
@@ -5294,14 +5303,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
     GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
 
-    const int32_t nrows = ne1*ne2*ne3;
-
-    // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
     const size_t offs_tmp = offs_dst + ggml_nbytes(dst);
 
-    // printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
-    // printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
-
     [encoder setBuffer:id_dst offset:offs_tmp atIndex:6];
 
     [encoder setThreadgroupMemoryLength:smem atIndex:0];
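`offs_tmp` carves the per-workgroup scratch region out of the same Metal buffer as `dst`, immediately after the tensor's own bytes; the retargeted [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT] tag now points at the `get_alloc_size` hook that actually reserves this space. A sketch of the layout and of the reserved size (hypothetical helper; nwg = 32 matches the constant used in the alloc-size code):

```c
#include <stddef.h>
#include <stdint.h>

// buffer layout for a vec-kernel FLASH_ATTN_EXT node (sketch):
//
//   offs_dst                     offs_tmp = offs_dst + ggml_nbytes(dst)
//   v                            v
//   [ dst tensor bytes ........ ][ nrows * nwg * (ne20 + 2) floats ]
//
// each output row gets one (ne20 + 2)-float record per workgroup:
// ne20 output values plus the S and M values used by the reduce pass
static size_t flash_attn_ext_tmp_size(int64_t nrows, int64_t ne20, int64_t nwg) {
    return sizeof(float)*nrows*nwg*(ne20 + 2);
}
```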
@@ -5312,6 +5316,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     // reduce the results from the workgroups
     {
+        const int32_t nrows = ne1*ne2*ne3;
+
         ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
             nrows,
         };
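The reduce dispatch folds the nwg partial results per row into the final output. The kernel source is not part of this diff, so the following is only a CPU sketch of the standard split-K softmax merge, under the assumption that each workgroup's record holds ne20 unnormalized accumulators followed by its local sum S and running max M:

```c
#include <math.h>

// merge nwg partial results for one output row (sketch)
// part: nwg records of (ne20 + 2) floats, laid out as [acc[0..ne20), S, M]
// out:  ne20 floats, the final normalized output
static void reduce_row(const float * part, float * out, int nwg, int ne20) {
    // global maximum over the per-workgroup running maxima
    float M = -INFINITY;
    for (int w = 0; w < nwg; ++w) {
        const float Mw = part[w*(ne20 + 2) + ne20 + 1];
        M = Mw > M ? Mw : M;
    }

    // rescale each partial sum/accumulator to the global maximum and combine
    float S = 0.0f;
    for (int i = 0; i < ne20; ++i) {
        out[i] = 0.0f;
    }
    for (int w = 0; w < nwg; ++w) {
        const float c = expf(part[w*(ne20 + 2) + ne20 + 1] - M);
        S += c*part[w*(ne20 + 2) + ne20 + 0];
        for (int i = 0; i < ne20; ++i) {
            out[i] += c*part[w*(ne20 + 2) + i];
        }
    }

    // final softmax normalization
    for (int i = 0; i < ne20; ++i) {
        out[i] /= S;
    }
}
```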
@@ -6150,30 +6156,31 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
 static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     size_t res = ggml_nbytes(tensor);
 
+    // some operations require additional memory for fleeting data:
     switch (tensor->op) {
         case GGML_OP_MUL_MAT_ID:
             {
-                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
-                const int64_t ne02 = tensor->src[0]->ne[2];
-                const int64_t ne21 = tensor->src[2]->ne[1];
+                const int64_t ne02 = tensor->src[0]->ne[2]; // n_expert
+                const int64_t ne21 = tensor->src[2]->ne[1]; // n_token
 
+                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
                 res += ggml_type_size(GGML_TYPE_I32)*ne02;
-                res += ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
+                res += ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
-                const int64_t nwg = 32;
+                if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
+                    const int64_t nwg = 32;
 
-                const int64_t ne01 = tensor->src[0]->ne[1];
-                const int64_t ne02 = tensor->src[0]->ne[2];
-                const int64_t ne03 = tensor->src[0]->ne[3];
-                const int64_t ne20 = tensor->src[2]->ne[0];
+                    const int64_t ne01 = tensor->src[0]->ne[1];
+                    const int64_t ne02 = tensor->src[0]->ne[2];
+                    const int64_t ne03 = tensor->src[0]->ne[3];
+                    const int64_t ne20 = tensor->src[2]->ne[0];
 
-                if (ne01 < 20) {
                     // temp buffer for writing the results from each workgroup
-                    // - ne20: the size of the head vector
+                    // - ne20: the size of the Value head
                     // - + 2: the S and M values for each intermediate result
+                    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
                     res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
                 }
             } break;
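With the guard now matching the dispatch predicate exactly, nodes whose head size rules out the vec kernel no longer reserve the unused scratch region. For scale, a worked example with hypothetical decode shapes (ne01 = 1 query, ne02 = 32 heads, ne03 = 1 sequence, ne20 = 128, nwg = 32):

```c
// 4 bytes * (1*32*1*32*(128 + 2)) = 532480 bytes, i.e. 520 KiB of extra
// scratch on top of the dst tensor itself
const size_t extra = sizeof(float)*(1*32*1*32*(128 + 2));
```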