@@ -2501,9 +2501,11 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
// Row counts used when the "small rows" flash attention variants are selected.
// The scalar path uses a different (smaller) value than the other paths.
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
25032503
// Returns the "large" row count for the scalar flash attention path, chosen
// from the two head-size arguments.
//
//   hsk, hsv - head sizes (presumably of K and V; callers pass HSK/HSV).
//
// Result:
//   2  when hsv >= 192 (checked first, regardless of hsk),
//   4  when bit 3 is set in either hsk or hsv, i.e. (hsv | hsk) & 8 is non-zero,
//   8  otherwise.
//
// NOTE(review): the `& 8` test is the same one fa_rows_cols uses to detect
// head sizes that are not a multiple of 16; that reading assumes head sizes
// are always multiples of 8 - confirm against the callers.
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) {
    if (hsv >= 192) {
        return 2;
    } else if ((hsv | hsk) & 8) {
        return 4;
    } else {
        return 8;
    }
}
@@ -2535,9 +2537,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
25352537 if ((hsv | hsk) & 8) {
25362538 // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
25372539 // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
2538- return {get_fa_scalar_num_large_rows(hsv), 64};
2540+ return {get_fa_scalar_num_large_rows(hsk, hsv), 64};
25392541 } else {
2540- return {get_fa_scalar_num_large_rows(hsv), 32};
2542+ return {get_fa_scalar_num_large_rows(hsk, hsv), 32};
25412543 }
25422544 }
25432545 }
@@ -7740,7 +7742,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
77407742 // Needs to be kept up to date on shader changes
77417743 GGML_UNUSED(hsv);
77427744 const uint32_t wg_size = scalar_flash_attention_workgroup_size;
7743- const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
7745+ const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv);
77447746 const uint32_t Bc = scalar_flash_attention_Bc;
77457747
77467748 const uint32_t tmpsh = wg_size * sizeof(float);
@@ -7871,7 +7873,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
78717873 case FA_SCALAR:
78727874 case FA_COOPMAT1:
78737875 // We may switch from coopmat1 to scalar, so use the scalar limit for both
7874- max_gqa = get_fa_scalar_num_large_rows(HSV);
7876+ max_gqa = get_fa_scalar_num_large_rows(HSK, HSV);
78757877 break;
78767878 case FA_COOPMAT2:
78777879 max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
0 commit comments