@@ -1590,7 +1590,8 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
15901590
15911591// number of rows/cols for flash attention shader
15921592static constexpr uint32_t flash_attention_num_small_rows = 32;
1593- static constexpr uint32_t scalar_flash_attention_num_small_rows = 8;
1593+ static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1594+ static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
15941595
15951596static uint32_t get_fa_num_small_rows(bool scalar) {
15961597 return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
@@ -1599,8 +1600,16 @@ static uint32_t get_fa_num_small_rows(bool scalar) {
15991600static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
16001601 GGML_UNUSED(clamp);
16011602
1603+ if (scalar) {
1604+ if (small_rows) {
1605+ return {scalar_flash_attention_num_small_rows, 64};
1606+ } else {
1607+ return {scalar_flash_attention_num_large_rows, 32};
1608+ }
1609+ }
1610+
16021611 // small rows, large cols
1603- if (small_rows || scalar ) {
1612+ if (small_rows) {
16041613 return {get_fa_num_small_rows(scalar), 32};
16051614 }
16061615
@@ -5729,8 +5738,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
57295738 assert(q->type == GGML_TYPE_F32);
57305739 assert(k->type == v->type);
57315740
5732- vk_pipeline *pipelines;
57335741 bool scalar = !ctx->device->coopmat2;
5742+
5743+ uint32_t gqa_ratio = 1;
5744+ uint32_t qk_ratio = neq2 / nek2;
5745+ uint32_t workgroups_x = (uint32_t)neq1;
5746+ uint32_t workgroups_y = (uint32_t)neq2;
5747+ uint32_t workgroups_z = (uint32_t)neq3;
5748+
5749+ // For scalar FA, we can use the "large" size to accommodate qga.
5750+ // For coopmat FA, we always use the small size (which is still pretty large for gqa).
5751+ const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
5752+
5753+ if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
5754+ qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5755+ // grouped query attention - make the N dimension equal to gqa_ratio, reduce
5756+ // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
5757+ // and change addressing calculations to index Q's dimension 2.
5758+ gqa_ratio = qk_ratio;
5759+ N = gqa_ratio;
5760+ workgroups_y /= N;
5761+ }
5762+
5763+ vk_pipeline *pipelines;
57345764 // XXX TODO other backends may be changing accumulator precision to default to f32 soon
57355765 bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
57365766 bool small_rows = N <= get_fa_num_small_rows(scalar);
@@ -5776,24 +5806,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
57765806 vk_pipeline pipeline = pipelines[aligned];
57775807 assert(pipeline);
57785808
5779- uint32_t gqa_ratio = 1;
5780- uint32_t qk_ratio = neq2 / nek2;
5781- uint32_t workgroups_x = (uint32_t)neq1;
5782- uint32_t workgroups_y = (uint32_t)neq2;
5783- uint32_t workgroups_z = (uint32_t)neq3;
5784-
5785- const uint32_t max_gqa = get_fa_num_small_rows(scalar);
5786-
5787- if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
5788- qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5789- // grouped query attention - make the N dimension equal to gqa_ratio, reduce
5790- // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
5791- // and change addressing calculations to index Q's dimension 2.
5792- gqa_ratio = qk_ratio;
5793- N = gqa_ratio;
5794- workgroups_y /= N;
5795- }
5796-
57975809 uint32_t split_kv = KV;
57985810 uint32_t split_k = 1;
57995811
0 commit comments