3 changes: 2 additions & 1 deletion ggml/src/ggml-cpu/ops.cpp
@@ -7253,7 +7253,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
const int iq3 = ir/(neq2*neq1);
const int iq3 = ir / (neq2*neq1);
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);

@@ -7425,6 +7425,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
}
}

void ggml_compute_forward_flash_attn_ext_mixed(
const ggml_compute_params * params,
const ggml_tensor * q,
41 changes: 28 additions & 13 deletions src/llama-graph.cpp
@@ -1630,6 +1630,22 @@ llm_graph_input_attn_kv_mixed * llm_graph_context::build_attn_inp_kv_mixed() con
inp->self_kq_mask_quant_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_quant, GGML_TYPE_F16) : inp->self_kq_mask_quant;
}

// Create state and result tensors for stateful flash attention
{
// State tensor: [2, n_heads * seq_len] for [M, S] pairs
const auto seq_len = n_tokens; // sequence length for current batch
inp->attn_state = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, n_head * seq_len);
ggml_set_input(inp->attn_state);
ggml_format_name(inp->attn_state, "attn_state");

// Result tensor: [head_dim, n_heads, seq_len, n_batch] - Fixed dimension order to match ggml_flash_attn_ext_with_state
const auto head_dim = n_embd_head_v;
const auto n_batch = 1;
inp->attn_result = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_head, seq_len, n_batch);
ggml_set_input(inp->attn_result);
ggml_format_name(inp->attn_result, "attn_result");
}

return (llm_graph_input_attn_kv_mixed *) res->add_input(std::move(inp));
}
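For context, the [M, S] pairs stored in attn_state are the running maximum and running softmax sum used by online (segmented) softmax: the FP16 and quantized KV segments can be attended separately and the partial results recombined afterwards. Below is a minimal sketch of that recombination for a single (head, query) pair; the names and the standalone function are illustrative only and are not the kernel in ops.cpp.

#include <algorithm>
#include <cmath>
#include <vector>

// Merge two partially computed attention segments for one (head, query) pair.
// Each segment carries an un-normalized value accumulator `acc` (length head_dim),
// the running maximum M of its attention logits, and the running softmax sum S.
void merge_segments(std::vector<float> & acc_a, float & M_a, float & S_a,
                    const std::vector<float> & acc_b, float M_b, float S_b) {
    const float M_new   = std::max(M_a, M_b);
    const float scale_a = std::exp(M_a - M_new);
    const float scale_b = std::exp(M_b - M_new);
    for (size_t d = 0; d < acc_a.size(); ++d) {
        acc_a[d] = acc_a[d]*scale_a + acc_b[d]*scale_b;
    }
    S_a = S_a*scale_a + S_b*scale_b;
    M_a = M_new;
    // The final output for this query is acc_a[d] / S_a once all segments are merged.
}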

@@ -1754,6 +1770,8 @@ ggml_tensor * llm_graph_context::build_attn_mha_with_state(
ggml_tensor * kq_mask_fp16,
ggml_tensor * kq_mask_quant,
ggml_tensor * v_mla,
ggml_tensor * state,
ggml_tensor * result,
float kq_scale) const {

GGML_UNUSED(gf);
@@ -1845,16 +1863,11 @@ ggml_tensor * llm_graph_context::build_attn_mha_with_state(
const auto n_kv_fp16 = k_fp16->ne[1]; // number of keys/values
const auto n_kv_quant = k_quant->ne[1]; // number of keys/values

// TODO : Modify these tensors to be non-dynamic alloc.
// Create state tensor: [2, n_heads * seq_len] for [M, S] pairs
ggml_tensor * state = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, n_head * seq_len);
ggml_set_input(state);
cb(state, "state", -1);
// Use pre-allocated state and result tensors to avoid dynamic allocation issues
GGML_ASSERT(state != nullptr && "State tensor must be pre-allocated");
GGML_ASSERT(result != nullptr && "Result tensor must be pre-allocated");

// Create output tensor
ggml_tensor * result = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32,
q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
ggml_set_input(result);
cb(state, "state", -1);
cb(result, "result", -1);

// Cast to F16 if needed for flash attention
@@ -1905,10 +1918,11 @@ ggml_tensor * llm_graph_context::build_attn_mha_with_state(
cb(result_quant, "attn_result_quant", -1);
cur = result;

// Reshape to final output format
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, seq_len);
// Reshape to final output format: [head_dim * n_heads, seq_len]
// cur has dimensions [head_dim, n_heads, seq_len, n_batch], so we flatten the first two dimensions
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
Comment on lines +1922 to +1923
Copilot AI Jun 25, 2025
This reshape only accounts for head_dim * n_heads by seq_len, dropping the n_batch dimension. The reshape dimensions must multiply to the total element count (including batch). Consider using cur->ne[0] * cur->ne[1] and cur->ne[2] * cur->ne[3] to preserve the batch size.

Suggested change
// cur has dimensions [head_dim, n_heads, seq_len, n_batch], so we flatten the first two dimensions
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
// cur has dimensions [head_dim, n_heads, seq_len, n_batch], so we flatten the first two dimensions while preserving the batch size
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);

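The constraint behind this suggestion is that a ggml reshape must preserve the total element count (ggml_reshape_2d asserts as much), so dropping ne[3] only happens to work while n_batch == 1, which is what build_attn_inp_kv_mixed() currently allocates. A standalone arithmetic check with illustrative sizes, not code from the PR:

#include <cassert>
#include <cstdint>

int main() {
    // Illustrative sizes; n_batch = 1 matches what build_attn_inp_kv_mixed() allocates today.
    const int64_t head_dim = 128, n_heads = 32, seq_len = 8, n_batch = 1;
    const int64_t total = head_dim * n_heads * seq_len * n_batch;

    // Suggested target: flatten the head dimension, fold the batch into the row count.
    assert((head_dim * n_heads) * (seq_len * n_batch) == total); // holds for any n_batch

    // Original target: the batch dimension is dropped.
    assert((head_dim * n_heads) * seq_len == total);             // only holds while n_batch == 1
    return 0;
}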

ggml_build_forward_expand(gf, cur);
ggml_build_forward_expand(gf, result);

return cur;
}
@@ -1957,7 +1971,8 @@ ggml_tensor * llm_graph_context::build_attn_mixed_with_state(
// Use the segmented flash attention with state
ggml_tensor * cur = build_attn_mha_with_state(
gf, q_cur, k_fp16, v_fp16, k_quant, v_quant,
kq_b, kq_mask_fp16, kq_mask_quant, v_mla, kq_scale
kq_b, kq_mask_fp16, kq_mask_quant, v_mla,
inp->get_state(), inp->get_result(), kq_scale
);
cb(cur, "kqv_out", il);

8 changes: 8 additions & 0 deletions src/llama-graph.h
@@ -313,11 +313,17 @@ class llm_graph_input_attn_kv_mixed : public llm_graph_input_i {

ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * get_kq_mask_quant() const { return self_kq_mask_quant_cnv; }
ggml_tensor * get_state() const { return attn_state; }
ggml_tensor * get_result() const { return attn_result; }

ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
ggml_tensor * self_kq_mask_quant = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_quant_cnv = nullptr; // [n_kv, n_batch]

// State and result tensors for stateful flash attention
ggml_tensor * attn_state = nullptr; // F32 [2, n_heads * seq_len] for [M, S] pairs
ggml_tensor * attn_result = nullptr; // F32 [head_dim, n_heads, seq_len, n_batch] output tensor

const llama_hparams & hparams;
const llama_cparams & cparams;
@@ -654,6 +660,8 @@ struct llm_graph_context {
ggml_tensor * kq_mask_fp16,
ggml_tensor * kq_mask_quant,
ggml_tensor * v_mla,
ggml_tensor * state,
ggml_tensor * result,
float kq_scale) const;

llm_graph_input_attn_cross * build_attn_inp_cross() const;
87 changes: 58 additions & 29 deletions src/llama-kv-cache-mixed.cpp
@@ -1003,7 +1003,7 @@ bool llama_kv_cache_mixed::find_slot(const llama_ubatch & ubatch) {
n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad))); //> Virtual head of kv cache.
n_quantized = std::min(size, std::max(n_pad, GGML_PAD(cell_max_quantized(), n_pad))); //> Virtual head of quantized kv cache.

LLAMA_LOG_INFO("\n[mixed-kv] successfully allocated slot: head=%u, used=%u, n=%u, n_quantized=%u, cell_max=%u, cell_max_quantized=%u\n", head, used, n, n_quantized, cell_max(), cell_max_quantized());
// LLAMA_LOG_INFO("\n[mixed-kv] successfully allocated slot: head=%u, used=%u, n=%u, n_quantized=%u, cell_max=%u, cell_max_quantized=%u\n", head, used, n, n_quantized, cell_max(), cell_max_quantized());

return true;
}
@@ -1029,16 +1029,45 @@ void llama_kv_cache_mixed::set_input_kq_mask(ggml_tensor * dst, const llama_ubat
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs = ubatch->n_seqs;

GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
// TODO: The checks below currently let llama-cli run, but when they bail out the mask is simply NOT computed.

// Basic tensor validation
if (!dst) {
LLAMA_LOG_INFO("[mixed-kv] dst tensor is null in set_input_kq_mask, skipping\n");
return;
}

// Check if buffer and data are allocated
// This can happen during graph building phase before allocation
if (!dst->buffer || !dst->data) {
LLAMA_LOG_INFO("[mixed-kv] Buffer or data not allocated yet for %s mask, skipping mask setting\n",
is_quantized ? "quantized" : "FP16");
return;
}

// Verify buffer is host-accessible before proceeding
if (!ggml_backend_buffer_is_host(dst->buffer)) {
LLAMA_LOG_INFO("[mixed-kv] Buffer for %s mask is not host-accessible\n",
is_quantized ? "quantized" : "FP16");
return;
}

float * data = (float *) dst->data;

// Final safety check for data pointer
if (!data) {
LLAMA_LOG_INFO("[mixed-kv] Data pointer is null for %s mask tensor\n",
is_quantized ? "quantized" : "FP16");
return;
}

// Choose the correct KV length based on whether we're setting mask for quantized or FP16 part
// - For FP16 mask (is_quantized=false): use n (covers recent tokens)
// - For quantized mask (is_quantized=true): use n_quantized (covers older tokens)
const int64_t n_kv = is_quantized ? n_quantized : n;

LLAMA_LOG_DEBUG("[mixed-kv] Setting %s mask: n_kv=%ld (n=%u, n_quantized=%u)\n",
is_quantized ? "quantized" : "FP16", n_kv, n, n_quantized);
LLAMA_LOG_DEBUG("[mixed-kv] Setting %s mask: n_kv=%lld (n=%u, n_quantized=%u)\n",
is_quantized ? "quantized" : "FP16", (long long)n_kv, n, n_quantized);

// Use only the previous KV cells of the correct sequence for each token of the ubatch.
// It's assumed that if a token in the batch has multiple sequences, they are equivalent.
Expand All @@ -1057,53 +1086,53 @@ void llama_kv_cache_mixed::set_input_kq_mask(ggml_tensor * dst, const llama_ubat
const llama_seq_id seq_id = ubatch->seq_id[s][0];

for (int j = 0; j < n_seq_tokens; ++j) {
// Current query token's position
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];

// Loop through all tokens in KV cache
for (int i = 0; i < n_kv; ++i) {
// Current key token's position
const llama_pos p0 = cells[i].pos; //> kv_cache idx.

bool masked = false;

// Rule 0: For mixed cache, check if cell belongs to the part we're masking
// - For FP16 mask: only include non-quantized cells
// - For quantized mask: only include quantized cells
// Get position from appropriate cell range
llama_pos p0;
if (is_quantized) {
masked = masked || (!cells[i].is_quantized()); // Skip non-quantized cells for quantized mask
// For quantized mask, use cells [0, n_quantized)
p0 = (i < cells.size()) ? cells[i].pos : -1;
} else {
masked = masked || (cells[i].is_quantized()); // Skip quantized cells for FP16 mask
// For FP16 mask, use cells [n_quantized, n_quantized + n)
uint32_t cell_idx = n_quantized + i;
p0 = (cell_idx < cells.size()) ? cells[cell_idx].pos : -1;
}

// Rule 1: If key token not in current query token's sequence, mask.
masked = masked || (!cells[i].has_seq_id(seq_id)); //> This cell is not in the current query token's sequence.
bool masked = false;

// Rule 2: If causal attention and key token after query token (future), mask.
masked = masked || (causal_attn && p0 > p1); //> p0 in SEQ_LEN > p1 in KV_LEN.
// mask if invalid cell
masked = masked || (p0 < 0);

// mask the token if not the same sequence
if (!masked && i < (int)cells.size()) {
uint32_t cell_idx = is_quantized ? i : (n_quantized + i);
if (cell_idx < cells.size()) {
masked = masked || (!cells[cell_idx].has_seq_id(seq_id));
}
}

// mask future tokens
masked = masked || (causal_attn && p0 > p1);

float f = 0.0f;

if (masked) {
// For masked tokens, set attention score to negative infinity
f = -INFINITY;
} else if (hparams.use_alibi) {
// Rule 3: If using ALiBi, compute penalty based on query-key distance
f = -std::abs(p0 - p1);
}

// Write computed mask value to destination tensor
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
}
}
}

// Rule 4: Mask padding tokens in batch (adapted for mixed KV cache)
if (data) {
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
// mask padded tokens
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
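Taken together, the rewritten inner loop makes the same masking decision per (query, key) pair: mask invalid cells, keys from other sequences, and future keys under causal attention, then apply an optional ALiBi distance bias. A condensed standalone sketch of that decision, using plain flags instead of the cache's cell structures (the helper and its parameters are illustrative):

#include <cmath>
#include <limits>

// Additive mask value for one (query, key) pair, mirroring the rules above.
float kq_mask_value(int key_pos, int query_pos, bool key_in_query_seq,
                    bool causal_attn, bool use_alibi) {
    bool masked = (key_pos < 0);                                   // invalid / out-of-range cell
    masked = masked || !key_in_query_seq;                          // key belongs to another sequence
    masked = masked || (causal_attn && key_pos > query_pos);       // future token under causal attention

    if (masked) {
        return -std::numeric_limits<float>::infinity();
    }
    // ALiBi: penalize by query-key distance; otherwise no bias.
    return use_alibi ? -std::abs((float) (key_pos - query_pos)) : 0.0f;
}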