@@ -906,6 +906,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
906906 const int64_t n_head_kv = hparams.n_head_kv (il);
907907 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa (il);
908908
909+ // TODO: move to model.get_freq_factors()
909910 const bool is_swa = hparams.is_swa (il);
910911
911912 // note: the swa rope params could become part of the cparams in the future
@@ -1600,6 +1601,181 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
16001601 return true ;
16011602}
16021603
1604+ //
1605+ // llama_kv_cache_unified_iswa
1606+ //
1607+
1608+ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa (
1609+ const llama_model & model,
1610+ ggml_type type_k,
1611+ ggml_type type_v,
1612+ bool v_trans,
1613+ bool offload,
1614+ uint32_t kv_size,
1615+ uint32_t padding) : hparams(model.hparams) {
1616+ llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams .is_swa (il); };
1617+ llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams .is_swa (il); };
1618+
1619+ // TODO: provide from the llama_context
1620+ const uint32_t n_seq_max = 1 ;
1621+
1622+ const uint32_t kv_size_base = kv_size;
1623+ const uint32_t kv_size_swa = hparams.n_swa *n_seq_max;
1624+
1625+ kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move (filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
1626+ kv_swa = std::make_unique<llama_kv_cache_unified>(model, std::move (filter_swa), type_k, type_v, v_trans, offload, kv_size_swa, padding);
1627+ }
1628+
1629+ void llama_kv_cache_unified_iswa::clear () {
1630+ kv_base->clear ();
1631+ kv_swa ->clear ();
1632+ }
1633+
1634+ bool llama_kv_cache_unified_iswa::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1635+ bool res = true ;
1636+
1637+ res = res & kv_base->seq_rm (seq_id, p0, p1);
1638+ res = res & kv_swa ->seq_rm (seq_id, p0, p1);
1639+
1640+ return res;
1641+ }
1642+
1643+ void llama_kv_cache_unified_iswa::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
1644+ kv_base->seq_cp (seq_id_src, seq_id_dst, p0, p1);
1645+ kv_swa ->seq_cp (seq_id_src, seq_id_dst, p0, p1);
1646+ }
1647+
1648+ void llama_kv_cache_unified_iswa::seq_keep (llama_seq_id seq_id) {
1649+ kv_base->seq_keep (seq_id);
1650+ kv_swa ->seq_keep (seq_id);
1651+ }
1652+
1653+ void llama_kv_cache_unified_iswa::seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
1654+ kv_base->seq_add (seq_id, p0, p1, delta);
1655+ kv_swa ->seq_add (seq_id, p0, p1, delta);
1656+ }
1657+
1658+ void llama_kv_cache_unified_iswa::seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
1659+ kv_base->seq_div (seq_id, p0, p1, d);
1660+ kv_swa ->seq_div (seq_id, p0, p1, d);
1661+ }
1662+
1663+ llama_pos llama_kv_cache_unified_iswa::seq_pos_max (llama_seq_id seq_id) const {
1664+ return kv_base->seq_pos_max (seq_id);
1665+ }
1666+
1667+ void llama_kv_cache_unified_iswa::restore () {
1668+ kv_base->restore ();
1669+ kv_swa ->restore ();
1670+ }
1671+
1672+ void llama_kv_cache_unified_iswa::commit () {
1673+ kv_base->commit ();
1674+ kv_swa ->commit ();
1675+ }
1676+
1677+ bool llama_kv_cache_unified_iswa::update (llama_context & lctx) {
1678+ bool res = true ;
1679+
1680+ res = res & kv_base->update (lctx);
1681+ res = res & kv_swa ->update (lctx);
1682+
1683+ return res;
1684+ }
1685+
1686+ void llama_kv_cache_unified_iswa::defrag_sched (float thold) {
1687+ kv_base->defrag_sched (thold);
1688+ kv_swa ->defrag_sched (thold);
1689+ }
1690+
1691+ void llama_kv_cache_unified_iswa::set_full () {
1692+ kv_base->set_full ();
1693+ kv_swa ->set_full ();
1694+ }
1695+
1696+ llama_sbatch llama_kv_cache_unified_iswa::sbatch_init (const llama_batch & batch, bool logits_all) {
1697+ return kv_base->sbatch_init (batch, logits_all);
1698+ }
1699+
1700+ llama_ubatch llama_kv_cache_unified_iswa::ubatch_next (llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1701+ return kv_base->ubatch_next (sbatch, n_ubatch, embd_pooled);
1702+ }
1703+
1704+ bool llama_kv_cache_unified_iswa::find_slot (const llama_ubatch & batch) {
1705+ bool res = true ;
1706+
1707+ res = res & kv_base->find_slot (batch);
1708+ res = res & kv_swa ->find_slot (batch);
1709+
1710+ return res;
1711+ }
1712+
1713+ int32_t llama_kv_cache_unified_iswa::get_n_tokens () const {
1714+ return kv_base->get_n_tokens ();
1715+ }
1716+
1717+ int32_t llama_kv_cache_unified_iswa::get_used_cells () const {
1718+ return kv_base->get_used_cells ();
1719+ }
1720+
1721+ llama_pos llama_kv_cache_unified_iswa::get_pos_max () const {
1722+ return kv_base->get_pos_max ();
1723+ }
1724+
1725+ bool llama_kv_cache_unified_iswa::get_can_shift () const {
1726+ return false ;
1727+ }
1728+
1729+ void llama_kv_cache_unified_iswa::state_write (llama_io_write_i & io, llama_seq_id seq_id) const {
1730+ kv_base->state_write (io, seq_id);
1731+ kv_swa ->state_write (io, seq_id);
1732+ }
1733+
1734+ void llama_kv_cache_unified_iswa::state_read (llama_io_read_i & io, llama_seq_id seq_id) {
1735+ kv_base->state_read (io, seq_id);
1736+ kv_swa ->state_read (io, seq_id);
1737+ }
1738+
1739+ ggml_tensor * llama_kv_cache_unified_iswa::get_k (ggml_context * ctx, int32_t il) const {
1740+ if (hparams.is_swa (il)) {
1741+ return kv_swa->get_k (ctx, il);
1742+ }
1743+
1744+ return kv_base->get_k (ctx, il);
1745+ }
1746+
1747+ ggml_tensor * llama_kv_cache_unified_iswa::get_v (ggml_context * ctx, int32_t il) const {
1748+ if (hparams.is_swa (il)) {
1749+ return kv_swa->get_v (ctx, il);
1750+ }
1751+
1752+ return kv_base->get_v (ctx, il);
1753+ }
1754+
1755+ ggml_tensor * llama_kv_cache_unified_iswa::cpy_k (ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1756+ if (hparams.is_swa (il)) {
1757+ return kv_swa->cpy_k (ctx, k_cur, il);
1758+ }
1759+
1760+ return kv_base->cpy_k (ctx, k_cur, il);
1761+ }
1762+
1763+ ggml_tensor * llama_kv_cache_unified_iswa::cpy_v (ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1764+ if (hparams.is_swa (il)) {
1765+ return kv_swa->cpy_v (ctx, v_cur, il);
1766+ }
1767+
1768+ return kv_base->cpy_v (ctx, v_cur, il);
1769+ }
1770+
1771+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_base () const {
1772+ return kv_base.get ();
1773+ }
1774+
1775+ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa () const {
1776+ return kv_swa.get ();
1777+ }
1778+
16031779//
16041780// llama_kv_cache_recurrent
16051781//
0 commit comments