@@ -1725,6 +1725,101 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+ggml_tensor * llm_graph_context::build_attn_mha_with_state(
+        ggml_cgraph * gf,
+        ggml_tensor * q,
+        ggml_tensor * k_fp16,
+        ggml_tensor * v_fp16,
+        ggml_tensor * k_quant,
+        ggml_tensor * v_quant,
+        ggml_tensor * kq_b,
+        ggml_tensor * kq_mask,
+        ggml_tensor * v_mla,
+              float   kq_scale) const {
+
+    // Simplified approach: pick whichever part of the mixed cache is available.
+    // In practice the mixed KV cache's get_k/get_v should already return merged
+    // views that contain both the FP16 and the dequantized data.
+
+    ggml_tensor * k_to_use = nullptr;
+    ggml_tensor * v_to_use = nullptr;
+
+    // prefer the FP16 cache if available, otherwise fall back to the quantized one
+    if (k_fp16 && v_fp16) {
+        k_to_use = k_fp16;
+        v_to_use = v_fp16;
+    } else if (k_quant && v_quant) {
+        k_to_use = k_quant;
+        v_to_use = v_quant;
+    } else {
+        GGML_ABORT("No valid KV cache found");
+    }
+
+    cb(k_to_use, "k_to_use", -1);
+    cb(v_to_use, "v_to_use", -1);
+
+    // run the standard attention on the selected KV tensors
+    ggml_tensor * cur = build_attn_mha(gf, q, k_to_use, v_to_use, kq_b, kq_mask, v_mla, kq_scale);
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_attn_mixed_with_state(
+        llm_graph_input_attn_kv_mixed * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+              float   kq_scale,
+                int   il) const {
+
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const llama_kv_cache_mixed * kv_self = static_cast<const llama_kv_cache_mixed *>(memory);
+
+    {
+        // store to KV cache
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+    cb(kq_mask, "KQ_mask", il);
+
+    // get the FP16 KV cache
+    ggml_tensor * k_fp16 = kv_self->get_k(ctx0, il);
+    ggml_tensor * v_fp16 = kv_self->get_v(ctx0, il);
+
+    // get the quantized KV cache
+    ggml_tensor * k_quant = kv_self->get_k_quant(ctx0, il);
+    ggml_tensor * v_quant = kv_self->get_v_quant(ctx0, il);
+
+    // use the new mixed attention with state
+    ggml_tensor * cur = build_attn_mha_with_state(
+        gf, q_cur, k_fp16, v_fp16, k_quant, v_quant,
+        kq_b, kq_mask, v_mla, kq_scale
+    );
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
     const int64_t max_distance = 128;
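
For context, a minimal sketch of how a per-layer call site for build_attn_mixed_with_state might look, modelled on the existing build_attn call pattern in the graph builders. This is an assumption, not part of this diff: the build_attn_inp_kv_mixed() helper and the inp object it returns are hypothetical, and the usual members of the graph-building context (ctx0, gf, model, n_embd_head, n_head, n_head_kv, n_tokens) are assumed to be in scope.

// sketch only: hypothetical call site inside a layer's graph build
// (inp is assumed to come from a build_attn_inp_kv_mixed() helper, which this diff does not add)
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

cur = build_attn_mixed_with_state(inp, gf,
        model.layers[il].wo, model.layers[il].bo,
        Qcur, Kcur, Vcur,
        /*kq_b  =*/ nullptr,
        /*v_mla =*/ nullptr,
        1.0f/sqrtf(float(n_embd_head)), il);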
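
Once k_to_use/v_to_use are selected, build_attn_mha_with_state defers to the regular build_attn_mha. As a reminder of what that path boils down to on the non-flash-attention side, here is a minimal, self-contained sketch of the masked-softmax attention core; the 2D per-head shapes and the transposed V layout are assumptions for illustration, not the exact tensor layout used by build_attn_mha.

#include "ggml.h"

// minimal sketch of the attention core applied to the selected K/V
// assumed shapes: q [n_embd_head, n_q], k [n_embd_head, n_kv], v [n_kv, n_embd_head] (transposed)
static ggml_tensor * attn_core_sketch(
        ggml_context * ctx,
        ggml_tensor  * q,
        ggml_tensor  * k,
        ggml_tensor  * v,
        ggml_tensor  * kq_mask,   // additive mask, -INF on masked positions
        float          kq_scale) {
    ggml_tensor * kq = ggml_mul_mat(ctx, k, q);               // [n_kv, n_q]
    kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f); // scale + mask + softmax
    return ggml_mul_mat(ctx, v, kq);                          // [n_embd_head, n_q]
}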