@@ -333,44 +333,33 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
333333}
334334
335335void llama_kv_cache_unified::restore () {
336- if (pending.ubatches .empty ()) {
337- return ;
338- }
339-
340- uint32_t new_head = size;
341-
342- for (const auto & ubatch : pending.ubatches ) {
343- for (uint32_t i = 0 ; i < ubatch.data .n_tokens ; ++i) {
344- for (int s = 0 ; s < ubatch.data .n_seq_id [i]; ++s) {
345- const llama_seq_id seq_id = ubatch.data .seq_id [i][s];
336+ for (const auto & [id, cell] : recovery.cells ) {
337+ // TODO: move to new `struct kv_cells`
338+ const bool is_empty0 = cells[id].is_empty ();
339+ const bool is_empty1 = cell.is_empty ();
346340
347- cells[ubatch.head + i].seq_id .erase (seq_id);
348- if (cells[ubatch.head + i].seq_id .empty ()) {
349- used--;
350-
351- new_head = std::min (new_head, ubatch.head + i);
352- }
353-
354- cells[ubatch.head + i].pos = -1 ;
355- }
341+ if (!is_empty0 && is_empty1) {
342+ used--;
343+ } else if (is_empty0 && !is_empty1) {
344+ used++;
356345 }
357- }
358346
359- if (new_head != size && new_head < head) {
360- head = new_head;
347+ cells[id] = cell;
361348 }
362349
363- pending.clear ();
350+ recovery.clear ();
351+ states.clear ();
364352}
365353
366354void llama_kv_cache_unified::commit () {
367- if (pending. ubatches .empty ()) {
368- LLAMA_LOG_WARN (" %s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n " ,
369- __func__, " https://github.com/ggml-org/llama.cpp/pull/12695 " );
355+ if (recovery. cells .empty ()) {
356+ LLAMA_LOG_WARN (" %s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n " ,
357+ __func__, " https://github.com/ggml-org/llama.cpp/pull/13194 " );
370358 return ;
371359 }
372360
373- pending.clear ();
361+ recovery.clear ();
362+ states.clear ();
374363}
375364
376365bool llama_kv_cache_unified::update (llama_context & lctx) {
@@ -460,16 +449,11 @@ void llama_kv_cache_unified::set_full() {
460449 head = 0 ;
461450}
462451
463- llama_sbatch llama_kv_cache_unified::sbatch_init (
464- const llama_batch & batch,
465- bool logits_all) {
452+ llama_sbatch llama_kv_cache_unified::sbatch_init (const llama_batch & batch, bool logits_all) {
466453 return llama_sbatch (batch, hparams.n_embd , true , logits_all);
467454}
468455
469- llama_ubatch llama_kv_cache_unified::ubatch_next (
470- llama_sbatch & sbatch,
471- uint32_t n_ubatch,
472- bool embd_pooled) const {
456+ llama_ubatch llama_kv_cache_unified::ubatch_next (llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
473457 GGML_UNUSED (embd_pooled);
474458 return sbatch.split_simple (n_ubatch);
475459}
@@ -490,6 +474,29 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
490474 return false ;
491475 }
492476
477+ // #define FIND_SLOT_DEBUG 1
478+ #if FIND_SLOT_DEBUG
479+ LLAMA_LOG_WARN (" begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n " , n, used, head, n_swa);
480+
481+ // for debugging
482+ {
483+ std::string ss;
484+ if (n_swa > 0 ) {
485+ for (uint32_t i = 0 ; i < size; ++i) {
486+ if (cells[i].pos == -1 ) {
487+ ss += ' .' ;
488+ } else {
489+ ss += std::to_string (*cells[i].seq_id .begin ());
490+ }
491+ if (i%256 == 255 ) {
492+ ss += ' \n ' ;
493+ }
494+ }
495+ }
496+ LLAMA_LOG_WARN (" \n %s\n " , ss.c_str ());
497+ }
498+ #endif
499+
493500 uint32_t n_tested = 0 ;
494501
495502 while (true ) {
@@ -520,6 +527,11 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
520527 }
521528
522529 for (uint32_t i = 0 ; i < n_tokens; ++i) {
530+ // remember the original state
531+ if (recovery.cells .find (head + i) == recovery.cells .end ()) {
532+ recovery.cells [head + i] = cells[head + i];
533+ }
534+
523535 cells[head + i].pos = ubatch.pos [i];
524536
525537 for (int32_t j = 0 ; j < ubatch.n_seq_id [i]; j++) {
@@ -529,18 +541,25 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
529541
530542 used += n_tokens;
531543
532- pending.ubatches .push_back ({ head, ubatch });
533-
534544 // a heuristic, to avoid attending the full cache if it is not yet utilized
535545 // after enough generations, the benefit from this heuristic disappears
536546 // if we start defragmenting the cache, the benefit from this will be more important
537547 n = std::min (size, std::max (padding, GGML_PAD (cell_max (), padding)));
538548
539- // printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
549+ states.push_back ({head, n});
550+
551+ #ifdef FIND_SLOT_DEBUG
552+ LLAMA_LOG_WARN (" end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n " , n, used, head, n_swa);
553+ #endif
540554
541555 return true ;
542556}
543557
558+ void llama_kv_cache_unified::set_state (int i) {
559+ head = states[i].head ;
560+ n = states[i].n ;
561+ }
562+
544563int32_t llama_kv_cache_unified::get_n_tokens () const {
545564 int32_t result = 0 ;
546565
@@ -642,6 +661,34 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
642661 return ggml_cpy (ctx, v_cur, v_view);
643662}
644663
664+ void llama_kv_cache_unified::prune_swa (llama_seq_id seq_id, llama_pos p1) {
665+ // no pruning is needed when the cache does not use SWA
666+ GGML_ASSERT (swa_type != LLAMA_SWA_TYPE_NONE && " do not prune non-SWA cache" );
667+
668+ for (uint32_t i = 0 ; i < size; ++i) {
669+ const llama_pos p0 = cells[i].pos ;
670+
671+ if (is_masked_swa (p0, p1)) {
672+ if (seq_id < 0 ) {
673+ cells[i].seq_id .clear ();
674+ } else if (cells[i].has_seq_id (seq_id)) {
675+ cells[i].seq_id .erase (seq_id);
676+ } else {
677+ continue ;
678+ }
679+
680+ if (cells[i].is_empty ()) {
681+ // keep count of the number of used cells
682+ if (cells[i].pos >= 0 ) {
683+ used--;
684+ }
685+
686+ cells[i].pos = -1 ;
687+ }
688+ }
689+ }
690+ }
691+
645692void llama_kv_cache_unified::set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
646693 const int64_t n_tokens = ubatch->n_tokens ;
647694 const int64_t n_seq_tokens = ubatch->n_seq_tokens ;
@@ -1181,6 +1228,10 @@ uint32_t llama_kv_cache_unified::cell_max() const {
11811228}
11821229
11831230bool llama_kv_cache_unified::is_masked_swa (llama_pos p0, llama_pos p1) const {
1231+ if (p0 < 0 ) {
1232+ return true ;
1233+ }
1234+
11841235 switch (swa_type) {
11851236 case LLAMA_SWA_TYPE_NONE:
11861237 {
@@ -1653,26 +1704,20 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
16531704void llama_kv_cache_unified_iswa::restore () {
16541705 kv_base->restore ();
16551706 kv_swa ->restore ();
1707+ states.clear ();
16561708}
16571709
16581710void llama_kv_cache_unified_iswa::commit () {
16591711 kv_base->commit ();
16601712 kv_swa ->commit ();
16611713
1662- if (pending.pos_max .empty ()) {
1663- return ;
1664- }
1665-
16661714 // slide the attention window, forgetting/pruning old tokens that are outside the window
16671715 for (const auto & [seq_id, pos_max] : pending.pos_max ) {
1668- if (pos_max <= (llama_pos) hparams.n_swa ) {
1669- continue ;
1670- }
1671-
1672- kv_swa->seq_rm (seq_id, -1 , pos_max - hparams.n_swa + 1 );
1716+ kv_swa->prune_swa (seq_id, pos_max);
16731717 }
16741718
1675- pending.pos_max .clear ();
1719+ pending.clear ();
1720+ states.clear ();
16761721}
16771722
16781723bool llama_kv_cache_unified_iswa::update (llama_context & lctx) {
@@ -1695,12 +1740,18 @@ void llama_kv_cache_unified_iswa::set_full() {
16951740}
16961741
16971742llama_sbatch llama_kv_cache_unified_iswa::sbatch_init (const llama_batch & batch, bool logits_all) {
1743+ pending.pos_max .clear ();
1744+
16981745 for (int i = 0 ; i < batch.n_tokens ; ++i) {
16991746 for (int s = 0 ; s < batch.n_seq_id [i]; ++s) {
17001747 const llama_seq_id seq_id = batch.seq_id [i][s];
17011748 const llama_pos pos = batch.pos [i];
17021749
1703- pending.pos_max [seq_id] = std::max (pending.pos_max [seq_id], pos);
1750+ if (pending.pos_max .find (seq_id) == pending.pos_max .end ()) {
1751+ pending.pos_max [seq_id] = pos;
1752+ } else {
1753+ pending.pos_max [seq_id] = std::max (pending.pos_max [seq_id], pos);
1754+ }
17041755 }
17051756 }
17061757
@@ -1721,6 +1772,11 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
17211772 return res;
17221773}
17231774
1775+ void llama_kv_cache_unified_iswa::set_state (int i) {
1776+ kv_base->set_state (i);
1777+ kv_swa ->set_state (i);
1778+ }
1779+
17241780int32_t llama_kv_cache_unified_iswa::get_n_tokens () const {
17251781 return kv_base->get_n_tokens ();
17261782}
@@ -2090,6 +2146,8 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
20902146}
20912147
20922148void llama_kv_cache_recurrent::restore () {
2149+ states.clear ();
2150+
20932151 if (pending.ranges .empty ()) {
20942152 return ;
20952153 }
@@ -2098,6 +2156,7 @@ void llama_kv_cache_recurrent::restore() {
20982156}
20992157
21002158void llama_kv_cache_recurrent::commit () {
2159+ states.clear ();
21012160 pending.ranges .clear ();
21022161}
21032162
@@ -2306,6 +2365,11 @@ bool llama_kv_cache_recurrent::find_slot(
23062365 return n >= n_seqs;
23072366}
23082367
2368+ void llama_kv_cache_recurrent::set_state (int i) {
2369+ head = states[i].head ;
2370+ n = states[i].n ;
2371+ }
2372+
23092373int32_t llama_kv_cache_recurrent::get_n_tokens () const {
23102374 int32_t result = 0 ;
23112375
0 commit comments