Skip to content

Commit 314e0e6

Browse files
committed
vulkan: allow FA split_k with smaller KV values
1 parent 6491d6e commit 314e0e6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6252,13 +6252,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
62526252
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
62536253

62546254
// Try to use split_k when KV is large enough to be worth the overhead
6255-
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
6255+
if (workgroups_x == 1 && shader_core_count > 0) {
62566256
// Try to run two workgroups per SM.
62576257
split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
62586258
if (split_k > 1) {
62596259
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
62606260
// of "align", so recompute split_k based on that.
6261-
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
6261+
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), pipelines[1]->align);
62626262
split_k = CEIL_DIV(KV, split_kv);
62636263
workgroups_x = split_k;
62646264
}

0 commit comments

Comments
 (0)