@@ -636,6 +636,7 @@ struct vk_flash_attn_push_constants {
636636 uint32_t nev3;
637637 uint32_t nem1;
638638 uint32_t nem2;
639+ uint32_t nem3;
639640
640641 uint32_t nb01;
641642 uint32_t nb02;
@@ -651,8 +652,7 @@ struct vk_flash_attn_push_constants {
651652 float max_bias;
652653 float logit_softcap;
653654
654- uint32_t mask;
655- uint32_t n_head_log2;
655+ uint32_t mask_n_head_log2;
656656 float m0;
657657 float m1;
658658
@@ -6111,6 +6111,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61116111
61126112 const uint32_t nem1 = mask ? mask->ne[1] : 0;
61136113 const uint32_t nem2 = mask ? mask->ne[2] : 0;
6114+ const uint32_t nem3 = mask ? mask->ne[3] : 0;
61146115
61156116 const uint32_t HSK = nek0;
61166117 const uint32_t HSV = nev0;
@@ -6178,7 +6179,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61786179 }
61796180
61806181 if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
6181- qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
6182+ qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
61826183 // grouped query attention - make the N dimension equal to gqa_ratio, reduce
61836184 // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
61846185 // and change addressing calculations to index Q's dimension 2.
@@ -6348,17 +6349,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
63486349 }
63496350 }
63506351
6352+ uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
6353+
63516354 const vk_flash_attn_push_constants pc = { N, KV,
63526355 (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
63536356 (uint32_t)neq2, (uint32_t)neq3,
63546357 (uint32_t)nek2, (uint32_t)nek3,
63556358 (uint32_t)nev2, (uint32_t)nev3,
6356- nem1, nem2,
6359+ nem1, nem2, nem3,
63576360 q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
63586361 k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
63596362 v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
63606363 scale, max_bias, logit_softcap,
6361- mask != nullptr, n_head_log2, m0, m1,
6364+ mask_n_head_log2, m0, m1,
63626365 gqa_ratio, split_kv, split_k };
63636366
63646367 ggml_vk_sync_buffers(subctx);
@@ -10303,12 +10306,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1030310306 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
1030410307 return false;
1030510308 }
10306- // TODO: support broadcast
10307- // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
10308- // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
10309- if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
10310- return false;
10311- }
1031210309 // It's straightforward to support different K/V dequant, but would
1031310310 // significantly increase the number of pipelines
1031410311 if (op->src[1]->type != op->src[2]->type) {
0 commit comments