@@ -6982,8 +6982,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
69826982 const int64_t rv3 = neq3/nev3;
69836983
69846984 // parallelize by q rows using ggml_vec_dot_f32
6985+ const uint32_t n_head = neq2;
6986+ const uint32_t n_head_log2 = 1u << (uint32_t ) floor (log2 (n_head));
69856987
6986- const int n_gqa = neq2 / nek2;
6988+ const uint32_t n_kv_head = nek2;
6989+ const int n_gqa = n_head / n_kv_head;
69876990 GGML_ASSERT (n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA);
69886991
69896992 // total groups in q
@@ -7008,9 +7011,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
70087011 scale /= logit_softcap;
70097012 }
70107013
7011- const uint32_t n_head = neq2;
7012- const uint32_t n_head_log2 = 1u << (uint32_t ) floor (log2 (n_head));
7013-
70147014 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
70157015 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
70167016
@@ -7031,8 +7031,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
70317031 float slope[GGML_FLASH_ATTN_EXT_MAX_GQA];
70327032
70337033 for (int ig = ig0; ig < ig1; ++ig) {
7034- const int group_index = ig % ng ;
7035- const int batch_index = ig / ng ;
7034+ const int group_index = ig % n_kv_head ;
7035+ const int batch_index = ig / n_kv_head ;
70367036 // q indices
70377037 const int iq3 = 0 ;
70387038 const int iq2 = group_index * n_gqa; // start head index
0 commit comments