src/llama-graph.cpp: 39 changes (27 additions, 12 deletions)
@@ -1630,6 +1630,22 @@ llm_graph_input_attn_kv_mixed * llm_graph_context::build_attn_inp_kv_mixed() con
         inp->self_kq_mask_quant_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_quant, GGML_TYPE_F16) : inp->self_kq_mask_quant;
     }

+    // Create state and result tensors for stateful flash attention
+    {
+        // State tensor: [2, n_heads * seq_len] for [M, S] pairs
+        const auto seq_len = n_tokens; // sequence length for the current batch
+        inp->attn_state = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, n_head * seq_len);
+        ggml_set_input(inp->attn_state);
+        ggml_format_name(inp->attn_state, "attn_state");
+
+        // Result tensor: [head_dim, n_heads, seq_len, n_batch], matching the dimension order of ggml_flash_attn_ext_with_state
+        const auto head_dim = n_embd_head_v;
+        const auto n_batch  = ubatch.n_seqs;
+        inp->attn_result = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_head, seq_len, n_batch);
+        ggml_set_input(inp->attn_result);
+        ggml_format_name(inp->attn_result, "attn_result");
+    }
+
     return (llm_graph_input_attn_kv_mixed *) res->add_input(std::move(inp));
 }
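[Note] The diff does not spell out what the [M, S] pairs in attn_state hold; in flash-attention terms they are presumably the online-softmax statistics: M the running maximum of the attention logits and S the running sum of exp(logit - M), one pair per (head, token). A minimal standalone sketch of how a per-segment partial result could be folded into such a state (plain C++, not the ggml API; all names are illustrative):

```cpp
#include <algorithm>
#include <cmath>

// Running online-softmax statistics for one (head, query-token) pair,
// matching the [M, S] layout of attn_state.
struct attn_state_ms {
    float M = -INFINITY; // running max of attention logits (nothing seen yet)
    float S = 0.0f;      // running sum of exp(logit - M)
};

// Fold a segment's normalized partial output o_new (with local stats m_new,
// s_new) into the accumulated output o_acc, rescaling both sides to the
// common maximum so the exp-sums stay comparable.
inline void merge_segment(attn_state_ms & st, float * o_acc,
                          float m_new, float s_new, const float * o_new,
                          int head_dim) {
    const float M = std::max(st.M, m_new);
    const float a = std::exp(st.M  - M) * st.S;  // weight of what we had
    const float b = std::exp(m_new - M) * s_new; // weight of the new segment
    for (int i = 0; i < head_dim; ++i) {
        o_acc[i] = (a * o_acc[i] + b * o_new[i]) / (a + b);
    }
    st.M = M;
    st.S = a + b;
}
```

Merging this way does not depend on segment order, which is what allows the FP16 and quantized KV segments below to be processed separately and combined through the shared state.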

@@ -1754,6 +1770,8 @@ ggml_tensor * llm_graph_context::build_attn_mha_with_state(
                  ggml_tensor * kq_mask_fp16,
                  ggml_tensor * kq_mask_quant,
                  ggml_tensor * v_mla,
+                 ggml_tensor * state,
+                 ggml_tensor * result,
                  float kq_scale) const {

     GGML_UNUSED(gf);
@@ -1845,16 +1863,11 @@ ggml_tensor * llm_graph_context::build_attn_mha_with_state(
     const auto n_kv_fp16  = k_fp16->ne[1];  // number of keys/values
     const auto n_kv_quant = k_quant->ne[1]; // number of keys/values

-    // TODO : Modify these tensors to be non-dynamic alloc.
-    // Create state tensor: [2, n_heads * seq_len] for [M, S] pairs
-    ggml_tensor * state = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, n_head * seq_len);
-    ggml_set_input(state);
-    cb(state, "state", -1);
+    // Use pre-allocated state and result tensors to avoid dynamic allocation issues
+    GGML_ASSERT(state  != nullptr && "State tensor must be pre-allocated");
+    GGML_ASSERT(result != nullptr && "Result tensor must be pre-allocated");

-    // Create output tensor
-    ggml_tensor * result = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32,
-            q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
-    ggml_set_input(result);
+    cb(state, "state", -1);
     cb(result, "result", -1);

     // Cast to F16 if needed for flash attention
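[Note] The motivation for threading state and result in as parameters: in ggml, tensors marked with ggml_set_input() must already exist when the graph allocator reserves storage for the whole graph, so creating them inside the attention builder on every call is what the removed TODO referred to as dynamic allocation. A rough standalone sketch of the intended lifecycle, assuming the usual ggml gallocr workflow (header names and setup details vary across ggml versions):

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include <vector>

int main() {
    // no_alloc context: only tensor metadata lives here,
    // storage comes from the graph allocator below
    ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
    ggml_context * ctx = ggml_init(ip);

    const int n_head = 8, seq_len = 16;

    // Pre-allocated graph input, as in build_attn_inp_kv_mixed()
    ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, n_head * seq_len);
    ggml_set_input(state);

    // Stand-in for the attention ops that consume the state
    ggml_tensor * out = ggml_scale(ctx, state, 2.0f);
    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    // One-shot allocation: every tensor, inputs included, gets stable storage
    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
    ggml_gallocr_alloc_graph(galloc, gf);

    // Per batch, only the input data is rewritten; the graph is never rebuilt
    std::vector<float> init(2 * n_head * seq_len, 0.0f);
    ggml_backend_tensor_set(state, init.data(), 0, ggml_nbytes(state));

    ggml_gallocr_free(galloc);
    ggml_free(ctx);
    return 0;
}
```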
@@ -1905,8 +1918,9 @@ ggml_tensor * llm_graph_context::build_attn_mha_with_state(
     cb(result_quant, "attn_result_quant", -1);
     cur = result;

-    // Reshape to final output format
-    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, seq_len);
+    // Reshape to final output format: [head_dim * n_heads, seq_len]
+    // cur has dimensions [head_dim, n_heads, seq_len, n_batch], so we flatten the first two dimensions
+    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
--- Review comment (Copilot AI, Jun 25, 2025) on lines +1922 to +1923 ---

This reshape only accounts for head_dim * n_heads by seq_len, dropping the n_batch dimension. The reshape dimensions must multiply to the total element count (including batch). Consider using cur->ne[0] * cur->ne[1] and cur->ne[2] * cur->ne[3] to preserve the batch size.

Suggested change:
-    // cur has dimensions [head_dim, n_heads, seq_len, n_batch], so we flatten the first two dimensions
-    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
+    // cur has dimensions [head_dim, n_heads, seq_len, n_batch], so we flatten the first two dimensions while preserving the batch size
+    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);
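[Note] The point of the suggestion is ggml's reshape invariant: ggml_reshape_2d requires the target shape to cover exactly as many elements as the source, so leaving n_batch out of the product fails the element-count assert whenever n_batch > 1. A tiny standalone check of the arithmetic (plain C++; the shape values are made up for illustration):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // [head_dim, n_heads, seq_len, n_batch] -- illustrative values only
    const int64_t ne[4] = {128, 32, 512, 2};
    const int64_t total = ne[0] * ne[1] * ne[2] * ne[3];

    // The committed reshape covers ne[0]*ne[1] x ne[2] elements:
    // half of them are missing once n_batch == 2.
    assert(ne[0] * ne[1] * ne[2] != total);

    // The suggested reshape folds the batch into the second dimension
    // and accounts for every element.
    assert(ne[0] * ne[1] * (ne[2] * ne[3]) == total);
    return 0;
}
```

The n_batch == 1 case is why the committed line can appear to work in simple tests.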

     ggml_build_forward_expand(gf, cur);

@@ -1957,7 +1971,8 @@ ggml_tensor * llm_graph_context::build_attn_mixed_with_state(
     // Use the segmented flash attention with state
     ggml_tensor * cur = build_attn_mha_with_state(
             gf, q_cur, k_fp16, v_fp16, k_quant, v_quant,
-            kq_b, kq_mask_fp16, kq_mask_quant, v_mla, kq_scale
+            kq_b, kq_mask_fp16, kq_mask_quant, v_mla,
+            inp->get_state(), inp->get_result(), kq_scale
     );
     cb(cur, "kqv_out", il);

src/llama-graph.h: 8 changes (8 additions, 0 deletions)
@@ -313,11 +313,17 @@ class llm_graph_input_attn_kv_mixed : public llm_graph_input_i {

     ggml_tensor * get_kq_mask()       const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_quant() const { return self_kq_mask_quant_cnv; }
+    ggml_tensor * get_state()  const { return attn_state; }
+    ggml_tensor * get_result() const { return attn_result; }

     ggml_tensor * self_kq_mask           = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv       = nullptr; //     [n_kv, n_batch]
     ggml_tensor * self_kq_mask_quant     = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_quant_cnv = nullptr; //     [n_kv, n_batch]

+    // State and result tensors for stateful flash attention
+    ggml_tensor * attn_state  = nullptr; // F32 [2, n_heads * seq_len] for [M, S] pairs
+    ggml_tensor * attn_result = nullptr; // F32 [head_dim, n_heads, seq_len, n_batch] output tensor

     const llama_hparams & hparams;
     const llama_cparams & cparams;
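[Note] Nothing in this diff shows where attn_state gets its initial values; presumably the input is written per ubatch, like the masks above, via the class's set_input() override. A hypothetical sketch of what that initialization could look like (not part of this diff; the init values follow from the [M, S] semantics, with M = -INFINITY and S = 0 meaning "nothing accumulated yet"):

```cpp
#include <cmath>
#include <vector>

// Hypothetical addition to the existing llm_graph_input_attn_kv_mixed::set_input().
void llm_graph_input_attn_kv_mixed::set_input(const llama_ubatch * ubatch) {
    // ... existing self_kq_mask / self_kq_mask_quant setup ...

    if (attn_state) {
        // one [M, S] pair per (head, token): start with an empty accumulator
        std::vector<float> init(ggml_nelements(attn_state));
        for (size_t i = 0; i < init.size(); i += 2) {
            init[i + 0] = -INFINITY; // M: no logit seen yet
            init[i + 1] = 0.0f;      // S: empty exp-sum
        }
        ggml_backend_tensor_set(attn_state, init.data(), 0, ggml_nbytes(attn_state));
    }
}
```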
@@ -654,6 +660,8 @@ struct llm_graph_context {
                  ggml_tensor * kq_mask_fp16,
                  ggml_tensor * kq_mask_quant,
                  ggml_tensor * v_mla,
+                 ggml_tensor * state,
+                 ggml_tensor * result,
                  float kq_scale) const;

     llm_graph_input_attn_cross * build_attn_inp_cross() const;