@@ -1735,7 +1735,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
17351735// number of rows/cols for flash attention shader
17361736static constexpr uint32_t flash_attention_num_small_rows = 32;
17371737static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1738- static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
1738+
1739+ static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
1740+ if (hsv >= 512) {
1741+ return 2;
1742+ } else {
1743+ return 8;
1744+ }
1745+ }
17391746
17401747// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
17411748// 128 threads split into four subgroups, each subgroup does 1/4
@@ -1760,7 +1767,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
17601767 if (small_rows) {
17611768 return {scalar_flash_attention_num_small_rows, 64};
17621769 } else {
1763- return {scalar_flash_attention_num_large_rows , 32};
1770+ return {get_fa_scalar_num_large_rows(hsv) , 32};
17641771 }
17651772 }
17661773
@@ -1779,7 +1786,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
17791786
17801787 // small cols to reduce register count
17811788 if (ggml_is_quantized(type) || hsk >= 256) {
1782- return {64, 32};
1789+ if (hsk >= 512) {
1790+ return {32, 32};
1791+ } else {
1792+ return {64, 32};
1793+ }
17831794 }
17841795 return {64, 64};
17851796}
@@ -6048,7 +6059,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
60486059 // Needs to be kept up to date on shader changes
60496060 GGML_UNUSED(hsv);
60506061 const uint32_t wg_size = scalar_flash_attention_workgroup_size;
6051- const uint32_t Br = scalar_flash_attention_num_large_rows ;
6062+ const uint32_t Br = get_fa_scalar_num_large_rows(hsv) ;
60526063 const uint32_t Bc = scalar_flash_attention_Bc;
60536064
60546065 const uint32_t tmpsh = wg_size * sizeof(float);
@@ -6173,7 +6184,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61736184 case FA_SCALAR:
61746185 case FA_COOPMAT1:
61756186 // We may switch from coopmat1 to scalar, so use the scalar limit for both
6176- max_gqa = scalar_flash_attention_num_large_rows ;
6187+ max_gqa = get_fa_scalar_num_large_rows(HSV) ;
61776188 break;
61786189 case FA_COOPMAT2:
61796190 max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
0 commit comments