4 changes: 2 additions & 2 deletions ggml/include/ggml.h
@@ -2222,8 +2222,8 @@ extern "C" {
     GGML_API void ggml_threadpool_params_init  (struct ggml_threadpool_params * p, int n_threads);
     GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
 
-    // Enhanced flash attention with state tensor for S/M values
-    // s_m_state: [2, n_heads * q_len] tensor containing [M, S] pairs for each head/position
+    // Enhanced flash attention with state tensor for S/M values and accumulated numerator
+    // s_m_state: [2 + head_dim, n_heads * q_len] tensor containing [M, S, VKQ...] for each head/position
     GGML_API struct ggml_tensor * ggml_flash_attn_ext_with_state(
             struct ggml_context * ctx,
             struct ggml_tensor * q,
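The new layout appends the head_dim-wide unnormalized accumulator (VKQ) to each row, so the caller sizes the state as `[head_dim + 2, n_heads * q_len]` and seeds `M = -INFINITY` to mark a row as "no segment processed yet". A minimal allocation/initialization sketch, assuming a `ggml_context` with data allocated; `make_fa_state` is illustrative, not part of ggml:

```c
#include <math.h>

#include "ggml.h"

// Sketch: allocate and seed the resumable attention state described above.
// Row layout per (head, position): [M, S, VKQ[0..head_dim-1]].
static struct ggml_tensor * make_fa_state(struct ggml_context * ctx,
        int64_t head_dim, int64_t n_heads, int64_t q_len) {
    struct ggml_tensor * state =
        ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_dim + 2, n_heads * q_len);
    float * d = (float *) state->data;
    const int64_t row = head_dim + 2;
    for (int64_t i = 0; i < n_heads * q_len; ++i) {
        d[i*row + 0] = -INFINITY; // M: running max logit; -INFINITY = untouched row
        d[i*row + 1] = 0.0f;      // S: running softmax denominator
        for (int64_t k = 0; k < head_dim; ++k) {
            d[i*row + 2 + k] = 0.0f; // VKQ: running unnormalized numerator
        }
    }
    return state;
}
```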
55 changes: 32 additions & 23 deletions ggml/src/ggml-cpu/ops.cpp
@@ -7174,19 +7174,19 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
     GGML_TENSOR_LOCALS(size_t,  nb, dst, nb)
 
-    // Validate state tensor format: [2, n_heads * q_len]
-    GGML_ASSERT(state != NULL);
-    GGML_ASSERT(state->ne[0] == 2);           // [M, S] pairs
-    GGML_ASSERT(state->ne[1] == neq2 * neq1); // n_heads * q_len
-    GGML_ASSERT(state->type == GGML_TYPE_F32);
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     const int64_t DK = nek0; //> head_dim
     const int64_t DV = nev0; //> head_dim
     const int64_t N  = neq1; //> q_len
 
+    // Validate state tensor format: [2 + DV, n_heads * q_len]
+    GGML_ASSERT(state != NULL);
+    GGML_ASSERT(state->ne[0] == DV + 2);      // [M, S, VKQ...]
+    GGML_ASSERT(state->ne[1] == neq2 * neq1); // n_heads * q_len
+    GGML_ASSERT(state->type == GGML_TYPE_F32);
+
     GGML_ASSERT(ne0 == DV); //> dst -> ne[0] == head_dim
     GGML_ASSERT(ne2 == N);  //> dst -> ne[2] == q_len
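The validation block moves below the `DK`/`DV`/`N` definitions because the `ne[0]` check now references `DV`. For orientation, one row of the flat F32 state can be viewed through accessors like these — a sketch mirroring the `state_idx`/`state_row` arithmetic in the next hunk, not helpers that exist in ggml:

```c
#include <stdint.h>

// Sketch: per-row views into the flat F32 state tensor.
static inline float * fa_state_row(float * state_data, int64_t state_idx, int64_t DV) {
    return state_data + state_idx*(DV + 2); // one row = [M, S, VKQ[0..DV-1]]
}
static inline float * fa_state_M  (float * row) { return row + 0; } // running max
static inline float * fa_state_S  (float * row) { return row + 1; } // running sum
static inline float * fa_state_VKQ(float * row) { return row + 2; } // DV floats
```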

@@ -7267,27 +7267,33 @@
 
         // Calculate state tensor offset for this head/position
         const int64_t state_idx = iq2 * neq1 + iq1; // head * q_len + position
-        float * state_data = (float *)state->data;
-
-        // Read initial S and M values from state tensor
-        // State format: [M, S] for each head/position
-        float S = state_data[state_idx * 2 + 1]; // sum (index 1)
-        float M = state_data[state_idx * 2 + 0]; // maximum KQ value (index 0)
-
-        // If this is the first call (indicated by M == -INFINITY), initialize properly
-        if (M == -INFINITY) {
-            S = 0.0f;
-        }
+        float * state_data = (float *) state->data;
+        float * state_row  = state_data + state_idx * (DV + 2);
+
+        // Read initial values
+        float M = state_row[0];
+        float S = state_row[1];
 
         float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
         float * V32   = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
         ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
         ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
 
         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
+            for (int64_t d = 0; d < DV; ++d) {
+                float val = (M == -INFINITY) ? 0.0f : state_row[2 + d];
+                VKQ16[d] = ggml_fp32_to_fp16(val);
+                VKQ32[d] = val; // keep FP32 version for later
+            }
         } else {
-            memset(VKQ32, 0, DV*sizeof(float));
+            for (int64_t d = 0; d < DV; ++d) {
+                VKQ32[d] = (M == -INFINITY) ? 0.0f : state_row[2 + d];
+            }
         }
 
+        if (M == -INFINITY) {
+            S = 0.0f;
+            M = -INFINITY;
+        }
+
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
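The per-key loop elided between this hunk and the next performs the standard online-softmax update; the `ms`/`vs` factors visible in the next hunk's first context line rescale everything accumulated so far whenever a new maximum logit arrives. A scalar sketch of one such step — the real kernel is vectorized and handles F16/quantized V, and `fa_online_step` is illustrative, not ggml code:

```c
#include <math.h>
#include <stdint.h>

// Sketch: one streaming-softmax step for a single (logit, value-row) pair.
// ms rescales the accumulated state; vs weights the incoming value row.
static void fa_online_step(float * M, float * S, float * VKQ, int64_t DV,
        float s, const float * v) {
    const float Mold = *M;
    const float Mnew = s > Mold ? s : Mold;
    const float ms = expf(Mold - Mnew); // 0 when Mold == -INFINITY and Mnew is finite
    const float vs = expf(s    - Mnew);
    for (int64_t d = 0; d < DV; ++d) {
        VKQ[d] = VKQ[d]*ms + vs*v[d];
    }
    *S = *S*ms + vs;
    *M = Mnew;
}
```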
@@ -7375,16 +7381,19 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
             S = S*ms + vs; // scale and increment sum with partial sum
         }
 
-        // Write updated S and M values back to state tensor
-        state_data[state_idx * 2 + 0] = M; // maximum KQ value (index 0)
-        state_data[state_idx * 2 + 1] = S; // sum (index 1)
-
         if (v->type == GGML_TYPE_F16) {
             for (int64_t d = 0; d < DV; ++d) {
                 VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
             }
         }
 
+        // Write updated S, M and VKQ values back to state tensor (before normalization)
+        state_row[0] = M;
+        state_row[1] = S;
+        for (int64_t d = 0; d < DV; ++d) {
+            state_row[2 + d] = VKQ32[d];
+        }
+
         // V /= S
         const float S_inv = 1.0f / S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
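Persisting VKQ before the `V /= S` normalization is the heart of the change: the stored triple (M, S, VKQ) is exactly the state of the streaming softmax, so a later segment can keep accumulating and only the final read needs the division. A self-contained check of that resumability, reusing the hypothetical `fa_online_step` from the sketch above:

```c
#include <assert.h>
#include <math.h>

// Sketch: accumulating 4 keys in one pass vs. two resumed segments must give
// the same normalized output, because (M, S, VKQ) is carried un-normalized.
static void fa_check_resume(void) {
    enum { DV = 2 };
    const float logits[4]   = { 0.3f, 1.7f, -0.5f, 0.9f };
    const float vals[4][DV] = { {1, 0}, {0, 1}, {2, 2}, {-1, 3} };

    // one pass over all keys
    float M1 = -INFINITY, S1 = 0.0f, A1[DV] = {0};
    for (int i = 0; i < 4; ++i) fa_online_step(&M1, &S1, A1, DV, logits[i], vals[i]);

    // two segments; the state carries over un-normalized between them
    float M2 = -INFINITY, S2 = 0.0f, A2[DV] = {0};
    for (int i = 0; i < 2; ++i) fa_online_step(&M2, &S2, A2, DV, logits[i], vals[i]);
    // ... here the kernel would write [M2, S2, A2] to the state tensor and a
    //     later graph execution would read it back ...
    for (int i = 2; i < 4; ++i) fa_online_step(&M2, &S2, A2, DV, logits[i], vals[i]);

    for (int d = 0; d < DV; ++d) {
        assert(fabsf(A1[d]/S1 - A2[d]/S2) < 1e-6f);
    }
}
```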
4 changes: 2 additions & 2 deletions ggml/src/ggml.c
@@ -4616,9 +4616,9 @@ struct ggml_tensor * ggml_flash_attn_ext_with_state(
         GGML_ASSERT(mask);
     }
 
-    // Validate state tensor format: [2, n_heads * q_len]
+    // Validate state tensor format: [2 + head_dim, n_heads * q_len]
     GGML_ASSERT(s_m_state != NULL);
-    GGML_ASSERT(s_m_state->ne[0] == 2);                   // [M, S] pairs
+    GGML_ASSERT(s_m_state->ne[0] == v->ne[0] + 2);        // [M, S, VKQ...] per head/position
    GGML_ASSERT(s_m_state->ne[1] == q->ne[2] * q->ne[1]); // n_heads * q_len
     GGML_ASSERT(s_m_state->type == GGML_TYPE_F32);
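The graph-build check now ties the state's first dimension to the value head size rather than to a constant. Callers can mirror the same check before building the graph; a hedged caller-side helper, not part of the ggml API:

```c
#include <stdbool.h>

#include "ggml.h"

// Sketch: caller-side mirror of the asserts above.
static bool fa_state_shape_ok(const struct ggml_tensor * state,
        const struct ggml_tensor * q, const struct ggml_tensor * v) {
    return state->type  == GGML_TYPE_F32
        && state->ne[0] == v->ne[0] + 2          // [M, S, VKQ...]
        && state->ne[1] == q->ne[2] * q->ne[1];  // n_heads * q_len
}
```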
29 changes: 18 additions & 11 deletions tests/test-flash-attn-state.cpp
@@ -88,11 +88,16 @@ static float tensor_max_diff(ggml_tensor* a, ggml_tensor* b) {
 
 static void reset_state_tensor(ggml_tensor* state) {
     float* state_data = (float*)state->data;
-    size_t n_pairs = ggml_nelements(state) / 2;
-
-    for (size_t i = 0; i < n_pairs; i++) {
-        state_data[i * 2 + 0] = -INFINITY; // M (max KQ value)
-        state_data[i * 2 + 1] = 0.0f;      // S (sum)
+    const int row      = state->ne[0]; // 2 + head_dim
+    const int head_dim = row - 2;
+    const int n_pairs  = state->ne[1];
+
+    for (int i = 0; i < n_pairs; i++) {
+        state_data[i * row + 0] = -INFINITY; // M
+        state_data[i * row + 1] = 0.0f;      // S
+        for (int d = 0; d < head_dim; ++d) {
+            state_data[i * row + 2 + d] = 0.0f; // VKQ
+        }
     }
 }

@@ -145,8 +150,8 @@ int main() {
     const int padded_seq_len = GGML_PAD(seq_len, GGML_KQ_MASK_PAD);
     ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, padded_kv_len, padded_seq_len);
 
-    // Create state tensor: [2, n_heads * seq_len] for [M, S] pairs
-    ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, n_heads * seq_len);
+    // Create state tensor: [2 + head_dim, n_heads * seq_len] for [M, S, VKQ]
+    ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2 + head_dim, n_heads * seq_len);
 
     print_tensor_info("Q", q);
     print_tensor_info("K", k);
@@ -230,8 +235,9 @@ int main() {
         // Print state before this segment
         printf("  State before segment %d: ", seg + 1);
         float* state_data = (float*)state->data;
+        int row = state->ne[0];
         for (int i = 0; i < std::min(4, n_heads * seq_len); i++) {
-            printf("[M=%.3f,S=%.3f] ", state_data[i * 2 + 0], state_data[i * 2 + 1]);
+            printf("[M=%.3f,S=%.3f] ", state_data[i * row + 0], state_data[i * row + 1]);
         }
         printf("...\n");
@@ -321,7 +327,7 @@ int main() {
         // Print state after this segment
         printf("  State after segment %d: ", seg + 1);
         for (int i = 0; i < std::min(4, n_heads * seq_len); i++) {
-            printf("[M=%.3f,S=%.3f] ", state_data[i * 2 + 0], state_data[i * 2 + 1]);
+            printf("[M=%.3f,S=%.3f] ", state_data[i * row + 0], state_data[i * row + 1]);
         }
         printf("...\n");
@@ -367,12 +373,13 @@ int main() {
     print_f32_sample("Final state", state, 16);
 
     float* state_data = (float*)state->data;
+    int row = state->ne[0];
     float min_m = INFINITY, max_m = -INFINITY;
     float min_s = INFINITY, max_s = -INFINITY;
 
     for (int i = 0; i < n_heads * seq_len; i++) {
-        float m_val = state_data[i * 2 + 0];
-        float s_val = state_data[i * 2 + 1];
+        float m_val = state_data[i * row + 0];
+        float s_val = state_data[i * row + 1];
 
         if (m_val != -INFINITY) {
             min_m = std::min(min_m, m_val);
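One property the min/max scan above can rely on, and which a stricter test could assert: once a row has processed at least one unmasked key, its M is finite and its S is at least 1, because the maximum logit contributes exactly expf(M - M) == 1 to the denominator and every other term is non-negative. A sketch of such a check — hypothetical, not in the test:

```c
#include <assert.h>
#include <math.h>

// Sketch: invariant for one state row after any number of segments.
static void check_state_row(float m_val, float s_val) {
    if (m_val != -INFINITY) {
        assert(s_val >= 1.0f - 1e-6f); // the max term contributes exactly 1
    } else {
        assert(s_val == 0.0f);         // untouched rows keep their reset values
    }
}
```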