Skip to content

Commit d414db0

Browse files
authored
vulkan: Use fewer rows for scalar FA when HS is not a multiple of 16 (ggml-org#17455)
1 parent 877566d commit d414db0

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2501,9 +2501,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
25012501
static constexpr uint32_t flash_attention_num_small_rows = 32;
25022502
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
25032503

2504-
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
2504+
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
25052505
if (hsv >= 192) {
25062506
return 2;
2507+
} else if ((hsv | hsk) & 8) {
2508+
return 4;
25072509
} else {
25082510
return 8;
25092511
}
@@ -2535,9 +2537,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
25352537
if ((hsv | hsk) & 8) {
25362538
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
25372539
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
2538-
return {get_fa_scalar_num_large_rows(hsv), 64};
2540+
return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
25392541
} else {
2540-
return {get_fa_scalar_num_large_rows(hsv), 32};
2542+
return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
25412543
}
25422544
}
25432545
}
@@ -7740,7 +7742,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
77407742
// Needs to be kept up to date on shader changes
77417743
GGML_UNUSED(hsv);
77427744
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
7743-
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
7745+
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
77447746
const uint32_t Bc = scalar_flash_attention_Bc;
77457747

77467748
const uint32_t tmpsh = wg_size * sizeof(float);
@@ -7871,7 +7873,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
78717873
case FA_SCALAR:
78727874
case FA_COOPMAT1:
78737875
// We may switch from coopmat1 to scalar, so use the scalar limit for both
7874-
max_gqa = get_fa_scalar_num_large_rows(HSV);
7876+
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
78757877
break;
78767878
case FA_COOPMAT2:
78777879
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);

tests/test-backend-ops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7859,6 +7859,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
78597859
}
78607860
}
78617861

7862+
// Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
7863+
test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
7864+
78627865
for (int kv : { 4096, 8192, 16384, }) {
78637866
for (int hs : { 64, 128, }) {
78647867
for (int nr : { 1, 4, }) {

0 commit comments

Comments
 (0)