@@ -333,44 +333,31 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
333333}
334334
// Roll back all cells modified since the last commit() to their remembered
// state. recovery.cells maps a cell index to the snapshot of that cell taken
// before it was overwritten (see find_slot()).
void llama_kv_cache_unified::restore() {
    for (const auto & [id, cell] : recovery.cells) {
        // TODO: move to new `struct kv_cells`
        const bool is_empty0 = cells[id].is_empty();
        const bool is_empty1 = cell.is_empty();

        // keep the `used` counter consistent with the restore:
        //   occupied -> empty frees a cell, empty -> occupied claims one
        if (!is_empty0 && is_empty1) {
            used--;
        } else if (is_empty0 && !is_empty1) {
            used++;
        }

        // restore the snapshot
        cells[id] = cell;
    }

    recovery.clear();
}
365352
// Finalize the pending cache updates: discard the recovery snapshots so that
// restore() can no longer roll them back.
void llama_kv_cache_unified::commit() {
    if (recovery.cells.empty()) {
        // a commit with nothing to recover suggests find_slot() was never
        // called (or commit() was called twice) - warn instead of failing
        LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
        return;
    }

    recovery.clear();
}
375362
376363bool llama_kv_cache_unified::update (llama_context & lctx) {
@@ -460,16 +447,11 @@ void llama_kv_cache_unified::set_full() {
460447 head = 0 ;
461448}
462449
463- llama_sbatch llama_kv_cache_unified::sbatch_init (
464- const llama_batch & batch,
465- bool logits_all) {
450+ llama_sbatch llama_kv_cache_unified::sbatch_init (const llama_batch & batch, bool logits_all) {
466451 return llama_sbatch (batch, hparams.n_embd , true , logits_all);
467452}
468453
469- llama_ubatch llama_kv_cache_unified::ubatch_next (
470- llama_sbatch & sbatch,
471- uint32_t n_ubatch,
472- bool embd_pooled) const {
454+ llama_ubatch llama_kv_cache_unified::ubatch_next (llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
473455 GGML_UNUSED (embd_pooled);
474456 return sbatch.split_simple (n_ubatch);
475457}
@@ -490,6 +472,29 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
490472 return false ;
491473 }
492474
475+ // #define FIND_SLOT_DEBUG 1
476+ #if FIND_SLOT_DEBUG
477+ LLAMA_LOG_WARN (" begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n " , n, used, head, n_swa);
478+
479+ // for debugging
480+ {
481+ std::string ss;
482+ if (n_swa > 0 ) {
483+ for (uint32_t i = 0 ; i < size; ++i) {
484+ if (cells[i].pos == -1 ) {
485+ ss += ' .' ;
486+ } else {
487+ ss += std::to_string (*cells[i].seq_id .begin ());
488+ }
489+ if (i%256 == 255 ) {
490+ ss += ' \n ' ;
491+ }
492+ }
493+ }
494+ LLAMA_LOG_WARN (" \n %s\n " , ss.c_str ());
495+ }
496+ #endif
497+
493498 uint32_t n_tested = 0 ;
494499
495500 while (true ) {
@@ -520,6 +525,11 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
520525 }
521526
522527 for (uint32_t i = 0 ; i < n_tokens; ++i) {
528+ // remember the original state
529+ if (recovery.cells .find (head + i) == recovery.cells .end ()) {
530+ recovery.cells [head + i] = cells[head + i];
531+ }
532+
523533 cells[head + i].pos = ubatch.pos [i];
524534
525535 for (int32_t j = 0 ; j < ubatch.n_seq_id [i]; j++) {
@@ -529,14 +539,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
529539
530540 used += n_tokens;
531541
532- pending.ubatches .push_back ({ head, ubatch });
533-
534542 // a heuristic, to avoid attending the full cache if it is not yet utilized
535543 // after enough generations, the benefit from this heuristic disappears
536544 // if we start defragmenting the cache, the benefit from this will be more important
537545 n = std::min (size, std::max (padding, GGML_PAD (cell_max (), padding)));
538546
539- // printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
547+ #ifdef FIND_SLOT_DEBUG
548+ LLAMA_LOG_WARN (" end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n " , n, used, head, n_swa);
549+ #endif
540550
541551 return true ;
542552}
@@ -642,6 +652,34 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
642652 return ggml_cpy (ctx, v_cur, v_view);
643653}
644654
// Remove from the cache all tokens that fall outside the sliding attention
// window ending at position p1 (as decided by is_masked_swa()).
// seq_id < 0 prunes the masked cells for all sequences; otherwise only the
// membership of seq_id is removed from the affected cells.
void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
    // no pruning is needed when the cache does not use SWA
    GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");

    for (uint32_t i = 0; i < size; ++i) {
        const llama_pos p0 = cells[i].pos;

        if (is_masked_swa(p0, p1)) {
            if (seq_id < 0) {
                cells[i].seq_id.clear();
            } else if (cells[i].has_seq_id(seq_id)) {
                cells[i].seq_id.erase(seq_id);
            } else {
                // cell is masked but does not belong to this sequence - leave it
                continue;
            }

            if (cells[i].is_empty()) {
                // keep count of the number of used cells
                // (pos < 0 means the cell was already free, so `used` does not change)
                if (cells[i].pos >= 0) {
                    used--;
                }

                cells[i].pos = -1;
            }
        }
    }
}
682+
645683void llama_kv_cache_unified::set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
646684 const int64_t n_tokens = ubatch->n_tokens ;
647685 const int64_t n_seq_tokens = ubatch->n_seq_tokens ;
@@ -1181,6 +1219,10 @@ uint32_t llama_kv_cache_unified::cell_max() const {
11811219}
11821220
11831221bool llama_kv_cache_unified::is_masked_swa (llama_pos p0, llama_pos p1) const {
1222+ if (p0 < 0 ) {
1223+ return true ;
1224+ }
1225+
11841226 switch (swa_type) {
11851227 case LLAMA_SWA_TYPE_NONE:
11861228 {
@@ -1589,13 +1631,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15891631 bool offload,
15901632 uint32_t kv_size,
15911633 uint32_t n_seq_max,
1592- uint32_t n_batch ,
1634+ uint32_t n_ubatch ,
15931635 uint32_t padding) : hparams(model.hparams) {
15941636 llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams .is_swa (il); };
15951637 llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams .is_swa (il); };
15961638
15971639 const uint32_t kv_size_base = kv_size;
1598- const uint32_t kv_size_swa = std::min (kv_size, GGML_PAD (hparams.n_swa *n_seq_max + n_batch , padding));
1640+ const uint32_t kv_size_swa = std::min (kv_size, GGML_PAD (hparams.n_swa *n_seq_max + n_ubatch , padding));
15991641
16001642 LLAMA_LOG_INFO (" %s: creating non-SWA KV cache, size = %u cells\n " , __func__, kv_size_base);
16011643
@@ -1659,20 +1701,12 @@ void llama_kv_cache_unified_iswa::commit() {
16591701 kv_base->commit ();
16601702 kv_swa ->commit ();
16611703
1662- if (pending.pos_max .empty ()) {
1663- return ;
1664- }
1665-
16661704 // slide the attention window, forgetting/pruning old tokens that are outside the window
16671705 for (const auto & [seq_id, pos_max] : pending.pos_max ) {
1668- if (pos_max <= (llama_pos) hparams.n_swa ) {
1669- continue ;
1670- }
1671-
1672- kv_swa->seq_rm (seq_id, -1 , pos_max - hparams.n_swa + 1 );
1706+ kv_swa->prune_swa (seq_id, pos_max);
16731707 }
16741708
1675- pending.pos_max . clear ();
1709+ pending.clear ();
16761710}
16771711
16781712bool llama_kv_cache_unified_iswa::update (llama_context & lctx) {
@@ -1695,12 +1729,18 @@ void llama_kv_cache_unified_iswa::set_full() {
16951729}
16961730
16971731llama_sbatch llama_kv_cache_unified_iswa::sbatch_init (const llama_batch & batch, bool logits_all) {
1732+ pending.pos_max .clear ();
1733+
16981734 for (int i = 0 ; i < batch.n_tokens ; ++i) {
16991735 for (int s = 0 ; s < batch.n_seq_id [i]; ++s) {
17001736 const llama_seq_id seq_id = batch.seq_id [i][s];
17011737 const llama_pos pos = batch.pos [i];
17021738
1703- pending.pos_max [seq_id] = std::max (pending.pos_max [seq_id], pos);
1739+ if (pending.pos_max .find (seq_id) == pending.pos_max .end ()) {
1740+ pending.pos_max [seq_id] = pos;
1741+ } else {
1742+ pending.pos_max [seq_id] = std::max (pending.pos_max [seq_id], pos);
1743+ }
17041744 }
17051745 }
17061746
0 commit comments