@@ -11,8 +11,6 @@
 #include <map>
 #include <stdexcept>
 
-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
@@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 return false;
             }
         }
+
+        return true;
     }
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() {
     }
 }
 
+void llama_kv_cache_unified::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        seq_rm(-1, -1, -1);
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    for (auto & range : pending.ranges) {
+        for (uint32_t i = range.c0; i < range.c1; ++i) {
+            cells[i].seq_id.clear();
+
+            // keep count of the number of used cells
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+        }
+
+        new_head = std::min(new_head, range.c0);
+    }
+
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::commit() {
+    if (pending.ranges.empty()) {
+        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+            __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+        return;
+    }
+
+    pending.ranges.clear();
+}
+
 bool llama_kv_cache_unified::get_can_shift() const {
     return can_shift;
 }
 
-llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+bool llama_kv_cache_unified::find_slot(
         const llama_ubatch & ubatch) {
     const uint32_t n_tokens     = ubatch.n_tokens;
     const uint32_t n_seqs       = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*ubatch.n_tokens) {
+        head = 0;
+    }
+
     if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
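
(Editor's note: the hunk above reads pending.ranges, range.c0 and range.c1 without showing their declaration, which lives in the header rather than in this file. A minimal sketch of what that header-side bookkeeping plausibly looks like, inferred from the half-open loop `for (uint32_t i = range.c0; i < range.c1; ++i)` and from the `{head, head + n_tokens}` push in find_slot() further down; the exact layout in the header is an assumption, not shown by this diff.)

    // sketch: assumed header-side declaration, inferred from usage in this diff
    struct slot_range {
        uint32_t c0 = 0; // first cell of the range (inclusive)
        uint32_t c1 = 0; // one past the last cell (exclusive)
    };

    struct {
        // cell ranges written by find_slot() since the last commit();
        // restore() walks these to roll the cells back, commit() clears them
        std::vector<slot_range> ranges;
    } pending;
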
@@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
                 LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-                return llama_kv_cache_slot_info_failed;
+                return false;
             }
             if (j > 0) {
                 llama_kv_cell & seq = cells[seq_id];
@@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
             [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return llama_kv_cache_slot_info(n >= n_seqs);
+        return n >= n_seqs;
     }
 
     // otherwise, one cell per token.
 
     if (n_tokens > size) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
-        return llama_kv_cache_slot_info_failed;
+        return false;
     }
 
     uint32_t n_tested = 0;
@@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
         if (n_tested >= size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
     }
 
@@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
     used += n_tokens;
 
-    return llama_kv_cache_slot_info(head, head + n_tokens);
+    pending.ranges.push_back({head, head + n_tokens});
+
+    return true;
 }
 
 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
@@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
+        commit();
 
         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
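
(Editor's note: taken together, the three calls form a small transaction around decoding: find_slot() reserves cells and records the reserved range in pending.ranges, restore() rolls those cells back if processing fails partway through a batch, and commit() makes the reservations permanent once the whole batch has been scheduled. A minimal caller-side sketch of that pattern follows; the function and variable names here are illustrative, not the actual llama_decode().)

    // hypothetical caller showing the intended find_slot()/commit()/restore() pattern
    static bool decode_ubatches(llama_kv_cache_unified & kv,
                                const std::vector<llama_ubatch> & ubatches) {
        for (const auto & ubatch : ubatches) {
            if (!kv.find_slot(ubatch)) {
                // undo the pending cell reservations made for earlier ubatches
                kv.restore();
                return false;
            }
            // ... build and evaluate the graph for this ubatch ...
        }

        kv.commit(); // all ubatches placed - forget the pending ranges
        return true;
    }
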