@@ -333,44 +333,32 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ubatches.empty()) {
-        return;
-    }
-
-    uint32_t new_head = size;
-
-    for (const auto & ubatch : pending.ubatches) {
-        for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
-            for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
+    for (const auto & [id, cell] : recovery.cells) {
+        // TODO: move to new `struct kv_cells`
+        const bool is_empty0 = cells[id].is_empty();
+        const bool is_empty1 = cell.is_empty();
 
-                cells[ubatch.head + i].seq_id.erase(seq_id);
-                if (cells[ubatch.head + i].seq_id.empty()) {
-                    used--;
-
-                    new_head = std::min(new_head, ubatch.head + i);
-                }
-
-                cells[ubatch.head + i].pos = -1;
-            }
+        if (!is_empty0 && is_empty1) {
+            used--;
+        } else if (is_empty0 && !is_empty1) {
+            used++;
         }
-    }
 
-    if (new_head != size && new_head < head) {
-        head = new_head;
+        cells[id] = cell;
     }
 
-    pending.clear();
+    recovery.clear();
+    states.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (pending.ubatches.empty()) {
-        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
-            __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+    if (recovery.cells.empty()) {
+        LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
+            __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
         return;
     }
 
-    pending.clear();
+    recovery.clear();
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
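
The hunk above replaces the old ubatch-replay rollback with a per-cell snapshot: find_slot() saves the original contents of every cell it is about to modify into recovery.cells (see the find_slot hunks below), restore() copies the saved cells back and re-derives the `used` counter from the empty/non-empty transitions, and commit() simply discards the snapshot. A minimal sketch of the resulting caller pattern; decode_ubatch() and the graph step are assumptions for illustration, only find_slot()/restore()/commit() come from this commit:

    // Sketch only: hypothetical wrapper, not part of this diff.
    bool decode_ubatch(llama_kv_cache_unified & kv, const llama_ubatch & ubatch) {
        if (!kv.find_slot(ubatch)) {
            kv.restore();   // roll every touched cell back to its saved copy
            return false;
        }

        // ... build and evaluate the graph for this ubatch ...

        kv.commit();        // success: drop the recovery snapshot
        return true;
    }
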
@@ -460,16 +448,11 @@ void llama_kv_cache_unified::set_full() {
     head = 0;
 }
 
-llama_sbatch llama_kv_cache_unified::sbatch_init(
-        const llama_batch & batch,
-        bool logits_all) {
+llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
     return llama_sbatch(batch, hparams.n_embd, true, logits_all);
 }
 
-llama_ubatch llama_kv_cache_unified::ubatch_next(
-        llama_sbatch & sbatch,
-        uint32_t n_ubatch,
-        bool embd_pooled) const {
+llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
     GGML_UNUSED(embd_pooled);
     return sbatch.split_simple(n_ubatch);
 }
@@ -490,6 +473,29 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
         return false;
     }
 
+//#define FIND_SLOT_DEBUG 1
+#if FIND_SLOT_DEBUG
+    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+
+    // for debugging
+    {
+        std::string ss;
+        if (n_swa > 0) {
+            for (uint32_t i = 0; i < size; ++i) {
+                if (cells[i].pos == -1) {
+                    ss += '.';
+                } else {
+                    ss += std::to_string(*cells[i].seq_id.begin());
+                }
+                if (i%256 == 255) {
+                    ss += '\n';
+                }
+            }
+        }
+        LLAMA_LOG_WARN("\n%s\n", ss.c_str());
+    }
+#endif
+
     uint32_t n_tested = 0;
 
     while (true) {
@@ -520,6 +526,11 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     }
 
     for (uint32_t i = 0; i < n_tokens; ++i) {
+        // remember the original state
+        if (recovery.cells.find(head + i) == recovery.cells.end()) {
+            recovery.cells[head + i] = cells[head + i];
+        }
+
         cells[head + i].pos = ubatch.pos[i];
 
         for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
@@ -529,18 +540,25 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
     used += n_tokens;
 
-    pending.ubatches.push_back({ head, ubatch });
-
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
     n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
 
-    // printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
+    states.push_back({head, n});
+
+#ifdef FIND_SLOT_DEBUG
+    LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+#endif
 
     return true;
 }
 
+void llama_kv_cache_unified::set_state(int i) {
+    head = states[i].head;
+    n = states[i].n;
+}
+
 int32_t llama_kv_cache_unified::get_n_tokens() const {
     int32_t result = 0;
 
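
In addition to the recovery snapshot, find_slot() now records the slot it picked as a {head, n} pair in a `states` vector, and the new set_state(i) re-applies the i-th pair. A rough sketch of how a caller that splits a batch into several ubatches might use this; the process_ubatches() wrapper and the graph step are assumptions, only find_slot() and set_state() come from this commit:

    // Sketch only: hypothetical multi-ubatch flow.
    void process_ubatches(llama_kv_cache_unified & kv, const std::vector<llama_ubatch> & ubatches) {
        // reserve a slot for every ubatch; each call appends {head, n} to `states`
        for (size_t i = 0; i < ubatches.size(); ++i) {
            kv.find_slot(ubatches[i]);
        }

        // before building the graph for ubatch i, re-select the head/n recorded for it
        for (size_t i = 0; i < ubatches.size(); ++i) {
            kv.set_state((int) i);
            // ... build and evaluate the graph for ubatches[i] ...
        }
    }
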
@@ -642,6 +660,34 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
+    // no pruning is needed when the cache does not use SWA
+    GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
+
+    for (uint32_t i = 0; i < size; ++i) {
+        const llama_pos p0 = cells[i].pos;
+
+        if (is_masked_swa(p0, p1)) {
+            if (seq_id < 0) {
+                cells[i].seq_id.clear();
+            } else if (cells[i].has_seq_id(seq_id)) {
+                cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+
+            if (cells[i].is_empty()) {
+                // keep count of the number of used cells
+                if (cells[i].pos >= 0) {
+                    used--;
+                }
+
+                cells[i].pos = -1;
+            }
+        }
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -1181,6 +1227,10 @@ uint32_t llama_kv_cache_unified::cell_max() const {
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    if (p0 < 0) {
+        return true;
+    }
+
     switch (swa_type) {
         case LLAMA_SWA_TYPE_NONE:
             {
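
The early return added here makes an empty cell (pos == -1) count as masked, so prune_swa() above can visit empty cells without a special case. For illustration, a standalone sketch of what a standard sliding-window check of this shape typically looks like; the per-type rules of the real member function live in the switch below and are not shown in this diff:

    // Sketch only: standard sliding-window predicate with an explicit n_swa
    // parameter; the member function uses the cache's own fields instead.
    static bool masked_swa_standard(llama_pos p0, llama_pos p1, uint32_t n_swa) {
        if (p0 < 0) {
            return true;                       // empty cell: always masked/prunable
        }
        return p1 - p0 >= (llama_pos) n_swa;   // key has slid out of the window
    }
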
@@ -1653,26 +1703,19 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 void llama_kv_cache_unified_iswa::restore() {
     kv_base->restore();
     kv_swa ->restore();
+    states.clear();
 }
 
 void llama_kv_cache_unified_iswa::commit() {
     kv_base->commit();
     kv_swa ->commit();
 
-    if (pending.pos_max.empty()) {
-        return;
-    }
-
     // slide the attention window, forgetting/pruning old tokens that are outside the window
     for (const auto & [seq_id, pos_max] : pending.pos_max) {
-        if (pos_max <= (llama_pos) hparams.n_swa) {
-            continue;
-        }
-
-        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
+        kv_swa->prune_swa(seq_id, pos_max);
     }
 
-    pending.pos_max.clear();
+    pending.clear();
 }
 
 bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
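
With this change the SWA half of the cache is pruned at commit time using the per-sequence position maxima collected in sbatch_init() (next hunk), instead of a seq_rm() call with a computed boundary. As a worked example, assuming the standard window predicate sketched above and n_swa = 4: if the highest committed position for sequence 0 is pos_max = 9, then prune_swa(0, 9) erases sequence 0 from every cell whose position p0 satisfies 9 - p0 >= 4, i.e. positions 0..5, while cells holding positions 6..9 remain resident in the SWA cache; the base cache keeps everything.
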
@@ -1695,12 +1738,18 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
+    pending.pos_max.clear();
+
     for (int i = 0; i < batch.n_tokens; ++i) {
         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
             const llama_seq_id seq_id = batch.seq_id[i][s];
             const llama_pos pos = batch.pos[i];
 
-            pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+            if (pending.pos_max.find(seq_id) == pending.pos_max.end()) {
+                pending.pos_max[seq_id] = pos;
+            } else {
+                pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+            }
         }
     }
 
@@ -1721,6 +1770,11 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
     return res;
 }
 
+void llama_kv_cache_unified_iswa::set_state(int i) {
+    kv_base->set_state(i);
+    kv_swa ->set_state(i);
+}
+
 int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
     return kv_base->get_n_tokens();
 }
@@ -2090,6 +2144,8 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_recurrent::restore() {
+    states.clear();
+
     if (pending.ranges.empty()) {
         return;
     }
@@ -2306,6 +2362,11 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
+void llama_kv_cache_recurrent::set_state(int i) {
+    head = states[i].head;
+    n = states[i].n;
+}
+
 int32_t llama_kv_cache_recurrent::get_n_tokens() const {
     int32_t result = 0;
 