Skip to content

Commit bd8e0bf

Browse files
committed
vulkan: use a smaller FA row size when the head size is large. Applies to both the scalar and CM2 paths (CM1 isn't used due to shared memory limits)
1 parent 2b54086 commit bd8e0bf

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,7 +1735,14 @@ static FaHeadSizes fa_get_head_sizes(uint32_t hsk, uint32_t hsv) {
17351735
// number of rows/cols for flash attention shader
17361736
static constexpr uint32_t flash_attention_num_small_rows = 32;
17371737
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1738-
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
1738+
1739+
// Number of rows used by the scalar flash-attention shader in the
// "large rows" configuration. A large head-value size (hsv >= 512)
// drops the row count from 8 to 2 to keep resource usage in bounds.
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) {
    return (hsv >= 512) ? 2 : 8;
}
17391746

17401747
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
17411748
// 128 threads split into four subgroups, each subgroup does 1/4
@@ -1760,7 +1767,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
17601767
if (small_rows) {
17611768
return {scalar_flash_attention_num_small_rows, 64};
17621769
} else {
1763-
return {scalar_flash_attention_num_large_rows, 32};
1770+
return {get_fa_scalar_num_large_rows(hsv), 32};
17641771
}
17651772
}
17661773

@@ -1779,7 +1786,11 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
17791786

17801787
// small cols to reduce register count
17811788
if (ggml_is_quantized(type) || hsk >= 256) {
1782-
return {64, 32};
1789+
if (hsk >= 512) {
1790+
return {32, 32};
1791+
} else {
1792+
return {64, 32};
1793+
}
17831794
}
17841795
return {64, 64};
17851796
}
@@ -6048,7 +6059,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
60486059
// Needs to be kept up to date on shader changes
60496060
GGML_UNUSED(hsv);
60506061
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
6051-
const uint32_t Br = scalar_flash_attention_num_large_rows;
6062+
const uint32_t Br = get_fa_scalar_num_large_rows(hsv);
60526063
const uint32_t Bc = scalar_flash_attention_Bc;
60536064

60546065
const uint32_t tmpsh = wg_size * sizeof(float);
@@ -6173,7 +6184,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61736184
case FA_SCALAR:
61746185
case FA_COOPMAT1:
61756186
// We may switch from coopmat1 to scalar, so use the scalar limit for both
6176-
max_gqa = scalar_flash_attention_num_large_rows;
6187+
max_gqa = get_fa_scalar_num_large_rows(HSV);
61776188
break;
61786189
case FA_COOPMAT2:
61796190
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);

0 commit comments

Comments
 (0)