Skip to content

Commit 45792af

Browse files
committed
allow gqa with dim3>1
1 parent 0cdb383 commit 45792af

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6120,7 +6120,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61206120
}
61216121

61226122
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
6123-
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1 && nem2 <= 1 && nem3 <= 1) {
6123+
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
61246124
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
61256125
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
61266126
// and change addressing calculations to index Q's dimension 2.

0 commit comments

Comments
 (0)