
Commit 104e5a0

[feature] Add ggml-flash-attn with kv segment.
Parent: a525b86 · Commit: 104e5a0

File tree

4 files changed (+122, -2 lines):

- src/llama-graph.cpp
- src/llama-graph.h
- src/llama-kv-cache-mixed.cpp
- src/llama-model.cpp

src/llama-graph.cpp

Lines changed: 95 additions & 0 deletions
```diff
@@ -1725,6 +1725,101 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+ggml_tensor * llm_graph_context::build_attn_mha_with_state(
+        ggml_cgraph * gf,
+        ggml_tensor * q,
+        ggml_tensor * k_fp16,
+        ggml_tensor * v_fp16,
+        ggml_tensor * k_quant,
+        ggml_tensor * v_quant,
+        ggml_tensor * kq_b,
+        ggml_tensor * kq_mask,
+        ggml_tensor * v_mla,
+              float   kq_scale) const {
+
+    // Simplified approach: just use the FP16 part for now.
+    // In practice, the mixed KV cache get_k/get_v should already return merged views
+    // that include both the FP16 and the dequantized data.
+
+    ggml_tensor * k_to_use = nullptr;
+    ggml_tensor * v_to_use = nullptr;
+
+    // prefer the FP16 cache if available, otherwise fall back to the quantized cache
+    if (k_fp16 && v_fp16) {
+        k_to_use = k_fp16;
+        v_to_use = v_fp16;
+    } else if (k_quant && v_quant) {
+        k_to_use = k_quant;
+        v_to_use = v_quant;
+    } else {
+        GGML_ABORT("No valid KV cache found");
+    }
+
+    cb(k_to_use, "k_to_use", -1);
+    cb(v_to_use, "v_to_use", -1);
+
+    // use the standard build_attn_mha with the selected KV cache
+    ggml_tensor * cur = build_attn_mha(gf, q, k_to_use, v_to_use, kq_b, kq_mask, v_mla, kq_scale);
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_attn_mixed_with_state(
+        llm_graph_input_attn_kv_mixed * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+              float   kq_scale,
+                int   il) const {
+
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const llama_kv_cache_mixed * kv_self = static_cast<const llama_kv_cache_mixed *>(memory);
+
+    {
+        // store to KV cache
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+    cb(kq_mask, "KQ_mask", il);
+
+    // get the FP16 KV cache views
+    ggml_tensor * k_fp16 = kv_self->get_k(ctx0, il);
+    ggml_tensor * v_fp16 = kv_self->get_v(ctx0, il);
+
+    // get the quantized KV cache views
+    ggml_tensor * k_quant = kv_self->get_k_quant(ctx0, il);
+    ggml_tensor * v_quant = kv_self->get_v_quant(ctx0, il);
+
+    // run the mixed attention with state
+    ggml_tensor * cur = build_attn_mha_with_state(
+        gf, q_cur, k_fp16, v_fp16, k_quant, v_quant,
+        kq_b, kq_mask, v_mla, kq_scale
+    );
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
     const int64_t max_distance = 128;
```
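Note that `build_attn_mha_with_state` currently picks whichever segment is populated (FP16 preferred) rather than attending over both. As a point of comparison only, the sketch below shows one hypothetical way the two segments could be combined before attention. `merge_kv_segments`, the assumed `[n_embd_head, n_head_kv, n_tokens]` view layout, and the use of `ggml_cast`/`ggml_concat` to dequantize-and-append are illustrative assumptions, not part of this commit.

```cpp
#include "ggml.h"

// Hypothetical sketch: merge the quantized (older) and FP16 (recent) KV segments.
// Assumes the backend can cast the quantized segment to F16 and that both views
// place tokens on dimension 2.
static ggml_tensor * merge_kv_segments(
        ggml_context * ctx,
        ggml_tensor  * seg_fp16,    // recent tokens, kept in FP16
        ggml_tensor  * seg_quant) { // older tokens, stored quantized
    if (seg_quant == nullptr) {
        return seg_fp16;
    }
    if (seg_fp16 == nullptr) {
        return ggml_cast(ctx, seg_quant, GGML_TYPE_F16);
    }
    // dequantize the older segment, then append the FP16 segment along the token axis
    ggml_tensor * seg_deq = ggml_cast(ctx, seg_quant, GGML_TYPE_F16);
    return ggml_concat(ctx, seg_deq, seg_fp16, 2);
}
```

A caller could then feed the merged K and V into the regular `build_attn_mha` path instead of choosing one segment; whether that is cheaper than keeping flash attention over two separate segments is exactly the trade-off this commit's "kv segment" design is exploring.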

src/llama-graph.h

Lines changed: 25 additions & 0 deletions
```diff
@@ -627,6 +627,31 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
+    ggml_tensor * build_attn_mixed_with_state(
+            llm_graph_input_attn_kv_mixed * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+                  float   kq_scale,
+                    int   il) const;
+
+    ggml_tensor * build_attn_mha_with_state(
+            ggml_cgraph * gf,
+            ggml_tensor * q,
+            ggml_tensor * k_fp16,
+            ggml_tensor * v_fp16,
+            ggml_tensor * k_quant,
+            ggml_tensor * v_quant,
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            ggml_tensor * v_mla,
+                  float   kq_scale) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
```
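The shape comments on `q_cur`, `k_cur`, and `v_cur` document the contract the new builders rely on. A minimal sketch of sanity checks matching those comments is shown below; the helper name and its placement are assumptions for illustration, not part of the commit.

```cpp
#include "ggml.h"

// Sketch (assumption): assert the invariants implied by the shape comments above.
static void check_attn_shapes(const ggml_tensor * q_cur,
                              const ggml_tensor * k_cur,
                              const ggml_tensor * v_cur) {
    // all three tensors must carry the same number of tokens (dimension 2)
    GGML_ASSERT(q_cur->ne[2] == k_cur->ne[2]);
    GGML_ASSERT(k_cur->ne[2] == v_cur->ne[2]);
    // Q and K must agree on the per-head embedding size so K^T·Q is well defined
    GGML_ASSERT(q_cur->ne[0] == k_cur->ne[0]);
}
```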

src/llama-kv-cache-mixed.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -1003,7 +1003,7 @@ bool llama_kv_cache_mixed::find_slot(const llama_ubatch & ubatch) {
     n           = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));           //> Virtual head of kv cache.
     n_quantized = std::min(size, std::max(n_pad, GGML_PAD(cell_max_quantized(), n_pad))); //> Virtual head of quantized kv cache.
 
-    LLAMA_LOG_INFO("\n[mixed-kv] successfully allocated slot: head=%u, used=%u, n=%u, n_quantized=%u, cell_max=%u, cell_max_quantized=%u\n", head, used, n, n_quantized, cell_max(), cell_max_quantized());
+    // LLAMA_LOG_INFO("\n[mixed-kv] successfully allocated slot: head=%u, used=%u, n=%u, n_quantized=%u, cell_max=%u, cell_max_quantized=%u\n", head, used, n, n_quantized, cell_max(), cell_max_quantized());
 
     return true;
 }
```
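The per-slot trace is commented out here rather than removed, since it fires on every `find_slot` call. A lighter-touch alternative, sketched below, would demote it to debug level so it only appears with verbose logging; this assumes the `LLAMA_LOG_DEBUG` macro from `llama-impl.h` is available in this tree.

```cpp
// Sketch (assumption): replacement for the commented-out line inside find_slot(),
// keeping the trace but emitting it only at debug verbosity.
LLAMA_LOG_DEBUG("[mixed-kv] allocated slot: head=%u, used=%u, n=%u, n_quantized=%u, cell_max=%u, cell_max_quantized=%u\n",
        head, used, n, n_quantized, cell_max(), cell_max_quantized());
```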

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -4622,7 +4622,7 @@ struct llm_build_llama : public llm_graph_context {
                 cb(Vcur, "Vcur", il);
 
                 if (dynamic_cast<const llama_kv_cache_mixed*>(memory)) {
-                    cur = build_attn(static_cast<llm_graph_input_attn_kv_mixed*>(inp_attn), gf,
+                    cur = build_attn_mixed_with_state(static_cast<llm_graph_input_attn_kv_mixed*>(inp_attn), gf,
                             model.layers[il].wo, model.layers[il].bo,
                             Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
                 } else {
```
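For context, the call site in `llm_build_llama` now dispatches on the cache type. The sketch below shows both branches; the `else` body is not part of this diff, so its cast target and arguments are an assumption based on the standard unified-cache `build_attn` path.

```cpp
// Sketch of the call-site dispatch (else-branch reconstructed, assumed):
if (dynamic_cast<const llama_kv_cache_mixed *>(memory)) {
    // mixed FP16 + quantized KV cache: route through the helper added in this commit
    cur = build_attn_mixed_with_state(static_cast<llm_graph_input_attn_kv_mixed *>(inp_attn), gf,
            model.layers[il].wo, model.layers[il].bo,
            Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
} else {
    // unified KV cache: keep the existing build_attn path (assumed cast and arguments)
    cur = build_attn(static_cast<llm_graph_input_attn_kv_unified *>(inp_attn), gf,
            model.layers[il].wo, model.layers[il].bo,
            Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
}
```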
