@@ -1656,27 +1656,38 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool v_trans,
         bool offload,
         uint32_t kv_size,
+        bool swa_full,
         uint32_t n_seq_max,
         uint32_t n_batch,
         uint32_t padding) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
-    const uint32_t kv_size_base = kv_size;
-    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+    const uint32_t size_base = kv_size;
 
-    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+
+    // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
+    if (swa_full) {
+        LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+
+        size_swa = size_base;
+        do_prune = false;
+    }
+
+    LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, kv_size_base, padding,
+            v_trans, offload, size_base, padding,
             0, LLAMA_SWA_TYPE_NONE);
 
-    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, kv_size_swa);
+    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, kv_size_swa, padding,
+            v_trans, offload, size_swa, padding,
             hparams.n_swa, hparams.swa_type);
 }
 
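For orientation, the sizing rule introduced above boils down to: the SWA cache only needs room for roughly one attention window per sequence plus one batch of new tokens, rounded up to the cache padding, while swa_full falls back to the full base size and turns pruning off. Below is a minimal standalone sketch of that arithmetic, using hypothetical values and with pad_up standing in for the GGML_PAD round-up macro:

// Standalone sketch of the SWA cache sizing above (hypothetical values;
// pad_up stands in for GGML_PAD, i.e. rounding up to a multiple of `padding`).
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t pad_up(uint32_t x, uint32_t n) { return (x + n - 1) / n * n; }

int main() {
    const uint32_t kv_size   = 8192; // base (non-SWA) cache size, in cells
    const uint32_t n_swa     = 1024; // sliding-window width (model-dependent)
    const uint32_t n_seq_max = 4;    // max parallel sequences
    const uint32_t n_batch   = 512;  // logical batch size
    const uint32_t padding   = 32;   // cache padding granularity
    const bool     swa_full  = false;

    const uint32_t size_base = kv_size;
    uint32_t       size_swa  = std::min(size_base, pad_up(n_swa*n_seq_max + n_batch, padding));

    bool do_prune = true;
    if (swa_full) {
        // full-size SWA cache: same size as the base cache, no pruning
        size_swa = size_base;
        do_prune = false;
    }

    printf("non-SWA cells: %u, SWA cells: %u, prune: %d\n", size_base, size_swa, do_prune);
    return 0;
}

With the values assumed here the SWA cache ends up at 4608 cells versus 8192 for the base cache.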
@@ -1733,8 +1744,11 @@ void llama_kv_cache_unified_iswa::commit() {
     kv_swa ->commit();
 
     // slide the attention window, forgetting/pruning old tokens that are outside the window
-    for (const auto & [seq_id, entry] : pending.pos) {
-        kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
+    if (do_prune) {
+        for (const auto & [seq_id, entry] : pending.pos) {
+            kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
+        }
+
     }
 
     pending.clear();
@@ -1762,17 +1776,19 @@ void llama_kv_cache_unified_iswa::set_full() {
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
     pending.clear();
 
-    for (int i = 0; i < batch.n_tokens; ++i) {
-        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[i][s];
-            const llama_pos pos = batch.pos[i];
+    if (do_prune) {
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = batch.seq_id[i][s];
+                const llama_pos pos = batch.pos[i];
 
-            if (pending.pos.find(seq_id) == pending.pos.end()) {
-                pending.pos[seq_id].pmin = pos;
-                pending.pos[seq_id].pmax = pos;
-            } else {
-                pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
-                pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
+                if (pending.pos.find(seq_id) == pending.pos.end()) {
+                    pending.pos[seq_id].pmin = pos;
+                    pending.pos[seq_id].pmax = pos;
+                } else {
+                    pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
+                    pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
+                }
             }
         }
     }
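Read together with the commit() change above, sbatch_init() records the [pmin, pmax] position range each sequence touches in the incoming batch, and commit() later hands that range to the SWA cache so cells that have slid out of the window can be dropped; both steps are skipped when do_prune is false. Below is a minimal sketch of that bookkeeping using simplified stand-in types and a printf in place of prune_swa(), not the actual llama.cpp structures:

// Sketch of the per-sequence [pmin, pmax] bookkeeping used for SWA pruning.
// The types and the "prune" step are simplified stand-ins, not llama.cpp API.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <map>
#include <utility>
#include <vector>

using seq_id_t = int32_t;
using pos_t    = int32_t;

struct pos_range { pos_t pmin; pos_t pmax; };

int main() {
    // one (seq_id, pos) pair per token of the batch (simplified: one sequence per token)
    const std::vector<std::pair<seq_id_t, pos_t>> batch = {
        {0, 100}, {0, 101}, {0, 102}, {1, 7}, {1, 8},
    };

    const bool do_prune = true;

    // sbatch_init(): accumulate the position range touched by each sequence
    std::map<seq_id_t, pos_range> pending;
    if (do_prune) {
        for (const auto & [seq_id, pos] : batch) {
            auto it = pending.find(seq_id);
            if (it == pending.end()) {
                pending[seq_id] = { pos, pos };
            } else {
                it->second.pmin = std::min(it->second.pmin, pos);
                it->second.pmax = std::max(it->second.pmax, pos);
            }
        }
    }

    // commit(): hand each range to the SWA cache so it can drop out-of-window cells
    if (do_prune) {
        for (const auto & [seq_id, entry] : pending) {
            printf("prune_swa(seq %d, pmin %d, pmax %d)\n", seq_id, entry.pmin, entry.pmax);
        }
    }
    pending.clear();
    return 0;
}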