@@ -634,6 +634,7 @@ struct vk_flash_attn_push_constants {
     uint32_t nev3;
     uint32_t nem1;
     uint32_t nem2;
+    uint32_t nem3;
 
     uint32_t nb01;
     uint32_t nb02;
@@ -649,8 +650,7 @@ struct vk_flash_attn_push_constants {
     float max_bias;
     float logit_softcap;
 
-    uint32_t mask;
-    uint32_t n_head_log2;
+    uint32_t mask_n_head_log2;
     float m0;
     float m1;
 
@@ -6050,6 +6050,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     const uint32_t nem1 = mask ? mask->ne[1] : 0;
     const uint32_t nem2 = mask ? mask->ne[2] : 0;
+    const uint32_t nem3 = mask ? mask->ne[3] : 0;
 
     const uint32_t D = neq0;
     uint32_t N = neq1;
@@ -6119,7 +6120,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
-        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1 && nem2 == 1 && nem3 == 1 ) {
         // grouped query attention - make the N dimension equal to gqa_ratio, reduce
         // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
         // and change addressing calculations to index Q's dimension 2.
@@ -6311,17 +6312,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
+    uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+
     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
                                               (uint32_t)neq2, (uint32_t)neq3,
                                               (uint32_t)nek2, (uint32_t)nek3,
                                               (uint32_t)nev2, (uint32_t)nev3,
-                                              nem1, nem2,
+                                              nem1, nem2, nem3,
                                               q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
                                               k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
                                               v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
                                               scale, max_bias, logit_softcap,
-                                              mask != nullptr, n_head_log2 , m0, m1,
+                                              mask_n_head_log2 , m0, m1,
                                               gqa_ratio, split_kv, split_k };
 
     ggml_vk_sync_buffers(subctx);
@@ -10265,12 +10268,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                     return false;
                 }
-                // TODO: support broadcast
-                // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
-                // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-                if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
-                    return false;
-                }
                 // It's straightforward to support different K/V dequant, but would
                 // significantly increase the number of pipelines
                 if (op->src[1]->type != op->src[2]->type) {
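
For context on the push-constant change: the diff folds the former `mask` and `n_head_log2` fields into a single 32-bit `mask_n_head_log2`, presumably so that adding the new `nem3` member does not grow the push-constant block. The mask flag is placed in the upper bits (`(mask != nullptr) << 16`) and `n_head_log2` stays in the lower 16 bits. The shader-side unpack is not part of these hunks; the sketch below is a minimal C++ model of the round trip, assuming the shader recovers the two values with the obvious shift and mask (the helper names `pack_mask_n_head_log2`, `unpack_mask`, and `unpack_n_head_log2` are illustrative, not identifiers from the codebase).

```cpp
#include <cassert>
#include <cstdint>

// Host side: mirrors the packing done in the diff above:
//   ((mask != nullptr) << 16) | n_head_log2
static uint32_t pack_mask_n_head_log2(bool has_mask, uint32_t n_head_log2) {
    return (uint32_t(has_mask) << 16) | n_head_log2;
}

// Shader side (assumed): a GLSL consumer would presumably split the field
// with a shift and a mask; these helpers model that in C++.
static bool unpack_mask(uint32_t packed) {
    return (packed >> 16) != 0;
}
static uint32_t unpack_n_head_log2(uint32_t packed) {
    return packed & 0xFFFFu;
}

int main() {
    const uint32_t packed = pack_mask_n_head_log2(/*has_mask=*/true, /*n_head_log2=*/5);
    assert(unpack_mask(packed));
    assert(unpack_n_head_log2(packed) == 5);
    return 0;
}
```

Since `n_head_log2` is small in practice, the lower 16 bits are more than enough for it, and the boolean flag costs no extra push-constant space once packed this way.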