diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index c3c8e954b12c6..d2d54eed6990c 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2222,8 +2222,8 @@ extern "C" {
     GGML_API void ggml_threadpool_params_init  (struct ggml_threadpool_params * p, int n_threads);
     GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
 
-    // Enhanced flash attention with state tensor for S/M values
-    // s_m_state: [2, n_heads * q_len] tensor containing [M, S] pairs for each head/position
+    // Enhanced flash attention with state tensor for S/M values and accumulated numerator
+    // s_m_state: [2 + head_dim, n_heads * q_len] tensor containing [M, S, VKQ...] for each head/position
     GGML_API struct ggml_tensor * ggml_flash_attn_ext_with_state(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index a614b2001bf64..a3b7697259566 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7174,12 +7174,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
     GGML_TENSOR_LOCALS(size_t,  nb, dst, nb)
 
-    // Validate state tensor format: [2, n_heads * q_len]
-    GGML_ASSERT(state != NULL);
-    GGML_ASSERT(state->ne[0] == 2);            // [M, S] pairs
-    GGML_ASSERT(state->ne[1] == neq2 * neq1);  // n_heads * q_len
-    GGML_ASSERT(state->type == GGML_TYPE_F32);
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -7187,6 +7181,12 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
     const int64_t DV = nev0; //> head_dim
     const int64_t N  = neq1; //> q_len
 
+    // Validate state tensor format: [2 + DV, n_heads * q_len]
+    GGML_ASSERT(state != NULL);
+    GGML_ASSERT(state->ne[0] == DV + 2);       // [M, S, VKQ...]
+    GGML_ASSERT(state->ne[1] == neq2 * neq1);  // n_heads * q_len
+    GGML_ASSERT(state->type == GGML_TYPE_F32);
+
     GGML_ASSERT(ne0 == DV); //> dst -> ne[0] == head_dim
     GGML_ASSERT(ne2 == N);  //> dst -> ne[2] == q_len
 
@@ -7267,17 +7267,12 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
         // Calculate state tensor offset for this head/position
         const int64_t state_idx = iq2 * neq1 + iq1; // head * q_len + position
 
-        float * state_data = (float *)state->data;
-
-        // Read initial S and M values from state tensor
-        // State format: [M, S] for each head/position
-        float S = state_data[state_idx * 2 + 1]; // sum (index 1)
-        float M = state_data[state_idx * 2 + 0]; // maximum KQ value (index 0)
-
-        // If this is the first call (indicated by M == -INFINITY), initialize properly
-        if (M == -INFINITY) {
-            S = 0.0f;
-        }
+        float * state_data = (float *) state->data;
+        float * state_row  = state_data + state_idx * (DV + 2);
+
+        // Read initial values
+        float M = state_row[0];
+        float S = state_row[1];
 
         float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
         float * V32   = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
@@ -7285,9 +7280,20 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
         ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
 
         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
+            for (int64_t d = 0; d < DV; ++d) {
+                float val = (M == -INFINITY) ? 0.0f : state_row[2 + d];
+                VKQ16[d] = ggml_fp32_to_fp16(val);
+                VKQ32[d] = val; // keep FP32 version for later
+            }
         } else {
-            memset(VKQ32, 0, DV*sizeof(float));
+            for (int64_t d = 0; d < DV; ++d) {
+                VKQ32[d] = (M == -INFINITY) ? 0.0f : state_row[2 + d];
+            }
+        }
+
+        if (M == -INFINITY) {
+            S = 0.0f;
+            M = -INFINITY;
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -7375,16 +7381,19 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
             S = S*ms + vs; // scale and increment sum with partial sum
         }
 
-        // Write updated S and M values back to state tensor
-        state_data[state_idx * 2 + 0] = M; // maximum KQ value (index 0)
-        state_data[state_idx * 2 + 1] = S; // sum (index 1)
-
         if (v->type == GGML_TYPE_F16) {
             for (int64_t d = 0; d < DV; ++d) {
                 VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
            }
         }
 
+        // Write updated S, M and VKQ values back to state tensor (before normalization)
+        state_row[0] = M;
+        state_row[1] = S;
+        for (int64_t d = 0; d < DV; ++d) {
+            state_row[2 + d] = VKQ32[d];
+        }
+
         // V /= S
         const float S_inv = 1.0f / S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 39e78d7052ac2..d748536b177dd 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -4616,9 +4616,9 @@ struct ggml_tensor * ggml_flash_attn_ext_with_state(
         GGML_ASSERT(mask);
     }
 
-    // Validate state tensor format: [2, n_heads * q_len]
+    // Validate state tensor format: [2 + head_dim, n_heads * q_len]
     GGML_ASSERT(s_m_state != NULL);
-    GGML_ASSERT(s_m_state->ne[0] == 2);                    // [M, S] pairs
+    GGML_ASSERT(s_m_state->ne[0] == v->ne[0] + 2);         // [M, S, VKQ...] per head/position
     GGML_ASSERT(s_m_state->ne[1] == q->ne[2] * q->ne[1]);  // n_heads * q_len
     GGML_ASSERT(s_m_state->type == GGML_TYPE_F32);
 
diff --git a/tests/test-flash-attn-state.cpp b/tests/test-flash-attn-state.cpp
index 7d1be7f02551f..31935d972099d 100644
--- a/tests/test-flash-attn-state.cpp
+++ b/tests/test-flash-attn-state.cpp
@@ -88,11 +88,16 @@ static float tensor_max_diff(ggml_tensor* a, ggml_tensor* b) {
 
 static void reset_state_tensor(ggml_tensor* state) {
     float* state_data = (float*)state->data;
-    size_t n_pairs = ggml_nelements(state) / 2;
-
-    for (size_t i = 0; i < n_pairs; i++) {
-        state_data[i * 2 + 0] = -INFINITY; // M (max KQ value)
-        state_data[i * 2 + 1] = 0.0f;      // S (sum)
+    const int row      = state->ne[0]; // 2 + head_dim
+    const int head_dim = row - 2;
+    const int n_pairs  = state->ne[1];
+
+    for (int i = 0; i < n_pairs; i++) {
+        state_data[i * row + 0] = -INFINITY; // M
+        state_data[i * row + 1] = 0.0f;      // S
+        for (int d = 0; d < head_dim; ++d) {
+            state_data[i * row + 2 + d] = 0.0f; // VKQ
+        }
     }
 }
 
@@ -145,8 +150,8 @@ int main() {
     const int padded_seq_len = GGML_PAD(seq_len, GGML_KQ_MASK_PAD);
     ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, padded_kv_len, padded_seq_len);
 
-    // Create state tensor: [2, n_heads * seq_len] for [M, S] pairs
-    ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, n_heads * seq_len);
+    // Create state tensor: [2 + head_dim, n_heads * seq_len] for [M, S, VKQ]
+    ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2 + head_dim, n_heads * seq_len);
 
     print_tensor_info("Q", q);
     print_tensor_info("K", k);
@@ -230,8 +235,9 @@ int main() {
         // Print state before this segment
         printf("  State before segment %d: ", seg + 1);
         float* state_data = (float*)state->data;
+        int row = state->ne[0];
         for (int i = 0; i < std::min(4, n_heads * seq_len); i++) {
-            printf("[M=%.3f,S=%.3f] ", state_data[i * 2 + 0], state_data[i * 2 + 1]);
+            printf("[M=%.3f,S=%.3f] ", state_data[i * row + 0], state_data[i * row + 1]);
         }
         printf("...\n");
 
@@ -321,7 +327,7 @@ int main() {
         // Print state after this segment
         printf("  State after segment %d: ", seg + 1);
         for (int i = 0; i < std::min(4, n_heads * seq_len); i++) {
-            printf("[M=%.3f,S=%.3f] ", state_data[i * 2 + 0], state_data[i * 2 + 1]);
+            printf("[M=%.3f,S=%.3f] ", state_data[i * row + 0], state_data[i * row + 1]);
         }
         printf("...\n");
 
@@ -367,12 +373,13 @@ int main() {
     print_f32_sample("Final state", state, 16);
 
     float* state_data = (float*)state->data;
+    int row = state->ne[0];
     float min_m = INFINITY, max_m = -INFINITY;
     float min_s = INFINITY, max_s = -INFINITY;
 
     for (int i = 0; i < n_heads * seq_len; i++) {
-        float m_val = state_data[i * 2 + 0];
-        float s_val = state_data[i * 2 + 1];
+        float m_val = state_data[i * row + 0];
+        float s_val = state_data[i * row + 1];
 
         if (m_val != -INFINITY) {
             min_m = std::min(min_m, m_val);