@@ -1345,6 +1345,17 @@ @implementation GGMLMetalClass
     return res;
 }
 
+// return true if we should use the FA vector kernel for this op
+static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int64_t ne00 = op->ne[0]; // head size
+    const int64_t ne01 = op->ne[1]; // batch size
+
+    // use vec kernel if the batch size is small and if the head size is supported
+    return (ne01 < 20) && (ne00 % 32 == 0);
+}
+
 static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
         ggml_backend_t backend, struct ggml_tensor * op,
         bool has_mask,
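The new helper centralizes the vec-kernel dispatch decision that was previously inlined in the encoder (and duplicated in the buffer-size logic further down): the vector kernel is used only for small batches (fewer than 20 query rows, i.e. decode-like workloads) and head sizes that are a multiple of 32. A standalone sketch of the heuristic, with hypothetical shapes picked purely for illustration:

```c
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

// same predicate as ggml_metal_flash_attn_ext_use_vec, but over raw shapes
static bool use_vec(int64_t head_size, int64_t batch_size) {
    return (batch_size < 20) && (head_size % 32 == 0);
}

int main(void) {
    printf("%d\n", use_vec(128,   1)); // 1: single-token decode -> vec kernel
    printf("%d\n", use_vec(128, 512)); // 0: large batch -> half8x8 kernel
    printf("%d\n", use_vec( 80,   1)); // 0: head size not a multiple of 32
    return 0;
}
```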
@@ -5066,9 +5077,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     GGML_ASSERT(ne01 < 65536);
 
-    // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
-    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
-    if (ne01 >= 20 || (ne00 % 32 != 0)) {
+    if (!ggml_metal_flash_attn_ext_use_vec(dst)) {
         // half8x8 kernel
         const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
         const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
@@ -5293,14 +5302,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
     GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
     GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
 
-    const int32_t nrows = ne1*ne2*ne3;
-
-    // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
+    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
     const size_t offs_tmp = offs_dst + ggml_nbytes(dst);
 
-    //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
-    //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
-
     [encoder setBuffer:id_dst offset:offs_tmp atIndex:6];
 
     [encoder setThreadgroupMemoryLength:smem atIndex:0];
@@ -5311,6 +5315,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     // reduce the results from the workgroups
     {
+        const int32_t nrows = ne1*ne2*ne3;
+
         ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
             nrows,
         };
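For context on the reduce step: on the vec path each output row is split across nwg workgroups, and every workgroup writes a partial output of length ne20 plus two extra floats, the softmax sum S and running maximum M, into the temp buffer; the flash_attn_ext_vec_reduce kernel then merges the nwg partials into the final row. A plain-C sketch of that merge, assuming an [S, M, O...] per-partial layout and unnormalized partial accumulators (both assumptions; the actual Metal kernel may order things differently):

```c
#include <math.h>

// merge `nwg` partial results for one output row of length `ne20`;
// each partial is laid out as [S, M, O_0 .. O_{ne20-1}] (assumed layout)
void fa_vec_reduce_row(const float * part, float * out, int nwg, int ne20) {
    // global maximum across the partial maxima
    float M = -INFINITY;
    for (int w = 0; w < nwg; ++w) {
        const float Mw = part[w*(ne20 + 2) + 1];
        if (Mw > M) M = Mw;
    }

    // rescale each partial by exp(M_w - M) and accumulate
    float S = 0.0f;
    for (int i = 0; i < ne20; ++i) {
        out[i] = 0.0f;
    }
    for (int w = 0; w < nwg; ++w) {
        const float Sw    = part[w*(ne20 + 2) + 0];
        const float Mw    = part[w*(ne20 + 2) + 1];
        const float scale = expf(Mw - M);

        S += Sw*scale;
        for (int i = 0; i < ne20; ++i) {
            out[i] += part[w*(ne20 + 2) + 2 + i]*scale;
        }
    }

    // final softmax normalization
    for (int i = 0; i < ne20; ++i) {
        out[i] /= S;
    }
}
```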
@@ -6149,30 +6155,31 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
 static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     size_t res = ggml_nbytes(tensor);
 
+    // some operations require additional memory for fleeting data:
     switch (tensor->op) {
         case GGML_OP_MUL_MAT_ID:
             {
-                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
-                const int64_t ne02 = tensor->src[0]->ne[2];
-                const int64_t ne21 = tensor->src[2]->ne[1];
+                const int64_t ne02 = tensor->src[0]->ne[2]; // n_expert
+                const int64_t ne21 = tensor->src[2]->ne[1]; // n_token
 
+                // [TAG_METAL_EXTRA_SIZE_OP_MUL_MAT_ID]
                 res += ggml_type_size(GGML_TYPE_I32)*ne02;
-                res += ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
+                res += ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
-                const int64_t nwg = 32;
+                if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
+                    const int64_t nwg = 32;
 
-                const int64_t ne01 = tensor->src[0]->ne[1];
-                const int64_t ne02 = tensor->src[0]->ne[2];
-                const int64_t ne03 = tensor->src[0]->ne[3];
-                const int64_t ne20 = tensor->src[2]->ne[0];
+                    const int64_t ne01 = tensor->src[0]->ne[1];
+                    const int64_t ne02 = tensor->src[0]->ne[2];
+                    const int64_t ne03 = tensor->src[0]->ne[3];
+                    const int64_t ne20 = tensor->src[2]->ne[0];
 
-                if (ne01 < 20) {
                     // temp buffer for writing the results from each workgroup
-                    // - ne20: the size of the head vector
+                    // - ne20: the size of the Value head
                     // - + 2: the S and M values for each intermediate result
+                    // [TAG_METAL_EXTRA_SIZE_OP_FLASH_ATTN_EXT]
                     res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
                 }
             } break;
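With this change the extra allocation is reserved only when the encoder will actually take the vec path, and both sites now agree through ggml_metal_flash_attn_ext_use_vec (cross-referenced by the TAG markers). A back-of-the-envelope check of the size formula under a hypothetical single-token decode shape (1 query, 32 heads, V head size 128):

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne01 = 1, ne02 = 32, ne03 = 1; // Q: tokens, heads, sequences (hypothetical)
    const int64_t ne20 = 128;                    // V head size (hypothetical)
    const int64_t nwg  = 32;                     // workgroups per output row

    // same expression as in ggml_backend_metal_buffer_type_shared_get_alloc_size
    const int64_t extra = (int64_t) sizeof(float)*(ne01*ne02*ne03*nwg*(ne20 + 2));
    printf("%lld bytes (%.1f KiB)\n", (long long) extra, extra/1024.0);
    // -> 532480 bytes (520.0 KiB)
    return 0;
}
```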