@@ -174,6 +174,7 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
             } else {
                 continue;
             }
+
             if (cells[i].is_empty()) {
                 // keep count of the number of used cells
                 if (cells[i].pos >= 0) {
@@ -340,6 +341,9 @@ void llama_kv_cache_unified::restore() {
         return;
     }

+    // TODO: here we assume that all sequences should be removed from the cache which is not always the case
+    // need to start keeping more detailed pending information per-sequence
+
     uint32_t new_head = size;

     for (auto & range : pending.ranges) {
@@ -1374,6 +1378,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
+
         commit();

         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
@@ -1569,9 +1574,10 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(

     // TODO: provide from the llama_context
     const uint32_t n_seq_max = 1;
+    const uint32_t n_batch   = hparams.n_swa;

     const uint32_t kv_size_base = kv_size;
-    const uint32_t kv_size_swa  = hparams.n_swa*n_seq_max;
+    const uint32_t kv_size_swa  = (hparams.n_swa + n_batch)*n_seq_max;

     kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
     kv_swa  = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa),  type_k, type_v, v_trans, offload, kv_size_swa,  padding);
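Why the SWA cache grows here: old tokens are only pruned from the SWA cache in commit(), after a decode succeeds, so while a batch is in flight the cache must hold the previous window plus the new batch. The sketch below just spells out that arithmetic with assumed values (n_swa, n_batch and n_seq_max are placeholders, not taken from this commit; per the TODO above, n_batch is temporarily set to hparams.n_swa because the real batch size is not yet provided by llama_context).

// sketch of the kv_size_swa sizing with assumed values
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_swa     = 4096; // assumed sliding-window width of the model
    const uint32_t n_batch   = 4096; // placeholder: the diff reuses hparams.n_swa here
    const uint32_t n_seq_max = 1;    // single sequence, as in the constructor above

    // same formula as kv_size_swa above: one full window plus one full batch, per sequence
    const uint32_t kv_size_swa = (n_swa + n_batch)*n_seq_max;

    printf("kv_size_swa = %u cells\n", kv_size_swa); // prints 8192
    return 0;
}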
@@ -1621,6 +1627,21 @@ void llama_kv_cache_unified_iswa::restore() {
 }

 void llama_kv_cache_unified_iswa::commit() {
+    if (pending.pos_max.empty()) {
+        return;
+    }
+
+    // slide the window, forgetting old tokens
+    for (const auto & [seq_id, pos_max] : pending.pos_max) {
+        if (pos_max <= (llama_pos) hparams.n_swa) {
+            continue;
+        }
+
+        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa);
+    }
+
+    pending.pos_max.clear();
+
     kv_base->commit();
     kv_swa ->commit();
 }
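The new commit() prunes the SWA cache lazily: only once a decode has succeeded does it drop, per sequence, every position older than pos_max - n_swa, so a failed decode can still be rolled back by restore() with the old window intact. Below is a standalone sketch of that pruning rule; the map type and the seq_rm callback are stand-ins for illustration, not the actual llama.cpp members.

// standalone sketch of the sliding-window pruning rule used in commit() above
#include <cstdint>
#include <map>

using llama_seq_id = int32_t;
using llama_pos    = int32_t;

// EraseFn(seq_id, p0, p1) removes positions [p0, p1), with p0 < 0 meaning "from the start"
template <typename EraseFn>
void prune_swa(std::map<llama_seq_id, llama_pos> & pos_max, llama_pos n_swa, EraseFn seq_rm) {
    for (const auto & [seq_id, pmax] : pos_max) {
        if (pmax <= n_swa) {
            continue; // the window is not full yet - nothing has slid out of it
        }

        // positions [0, pmax - n_swa) can never be attended to again by this sequence
        seq_rm(seq_id, -1, pmax - n_swa);
    }

    pos_max.clear(); // the pending info is consumed once the decode is committed
}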
@@ -1645,6 +1666,16 @@ void llama_kv_cache_unified_iswa::set_full() {
 }

 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
+    // this will be used upon successful decode, during commit, to remove old SWA tokens
+    for (int i = 0; i < batch.n_tokens; ++i) {
+        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+            const llama_seq_id seq_id = batch.seq_id[i][s];
+            const llama_pos    pos    = batch.pos[i];
+
+            pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+        }
+    }
+
     return kv_base->sbatch_init(batch, logits_all);
 }

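sbatch_init() now records, per sequence, the highest position the incoming batch will write, and commit() later uses that to decide how far the window may slide. The struct below is only a guess at the shape of the pending bookkeeping implied by the loop above (a map is assumed because the code indexes with operator[] and iterates with structured bindings); the real member added by the commit may look different.

// assumed shape of the per-sequence pending bookkeeping (hypothetical, for illustration)
#include <algorithm>
#include <cstdint>
#include <map>

using llama_seq_id = int32_t;
using llama_pos    = int32_t;

struct iswa_pending {
    // highest position seen per sequence since the last commit();
    // operator[] value-initializes unseen sequences to pos 0
    std::map<llama_seq_id, llama_pos> pos_max;
};

// mirrors the loop in sbatch_init() above for a flat batch (one seq_id per token)
void track_positions(iswa_pending & pending, const llama_seq_id * seq_id, const llama_pos * pos, int n_tokens) {
    for (int i = 0; i < n_tokens; ++i) {
        pending.pos_max[seq_id[i]] = std::max(pending.pos_max[seq_id[i]], pos[i]);
    }
}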