Skip to content

Commit 54b99d2

Browse files
committed
fix
Signed-off-by: ZelinMa557 <[email protected]>
1 parent d3b5101 commit 54b99d2

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6982,8 +6982,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
69826982
const int64_t rv3 = neq3/nev3;
69836983

69846984
// parallelize by q rows using ggml_vec_dot_f32
6985+
const uint32_t n_head = neq2;
6986+
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
69856987

6986-
const int n_gqa = neq2 / nek2;
6988+
const uint32_t n_kv_head = nek2;
6989+
const int n_gqa = n_head / n_kv_head;
69876990
GGML_ASSERT(n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA);
69886991

69896992
// total groups in q
@@ -7008,9 +7011,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
70087011
scale /= logit_softcap;
70097012
}
70107013

7011-
const uint32_t n_head = neq2;
7012-
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
7013-
70147014
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
70157015
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
70167016

@@ -7031,8 +7031,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
70317031
float slope[GGML_FLASH_ATTN_EXT_MAX_GQA];
70327032

70337033
for (int ig = ig0; ig < ig1; ++ig) {
7034-
const int group_index = ig % ng;
7035-
const int batch_index = ig / ng;
7034+
const int group_index = ig % n_kv_head;
7035+
const int batch_index = ig / n_kv_head;
70367036
// q indices
70377037
const int iq3 = 0;
70387038
const int iq2 = group_index * n_gqa; // start head index

0 commit comments

Comments
 (0)