@@ -161,6 +161,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 
     for (uint32_t i = 0; i < size; ++i) {
         if (cells[i].pos >= p0 && cells[i].pos < p1) {
+            pending.seq_rms.push_back({ seq_id, cells[i].pos, i });
+
             if (seq_id < 0) {
                 cells[i].seq_id.clear();
             } else if (cells[i].has_seq_id(seq_id)) {
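Note on the new pending.seq_rms entries: the restore() hunk below reads them back as seq_rm.seq_id, seq_rm.p and seq_rm.c, but the record itself is declared in the header, which is not part of this diff. A minimal sketch of the assumed shape (field names inferred from the usages in this diff) is:

// assumed shape of one pending seq_rm record (declaration not shown in this diff):
// one entry per cell touched by seq_rm(), so restore() can put the cell back
struct slot_seq_rm {
    llama_seq_id seq_id; // sequence id that was removed from the cell
    llama_pos    p;      // position the cell held before the removal
    uint32_t     c;      // index of the affected cell
};
// presumably collected as pending.seq_rms (std::vector<slot_seq_rm>) and replayed in restore()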
@@ -331,43 +333,58 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         return;
     }
 
-    // TODO: here we assume that all sequences should be removed from the cache which is not always the case
-    // need to start keeping more detailed pending information per-sequence
-
     uint32_t new_head = size;
 
-    for (auto & range : pending.ranges) {
-        for (uint32_t i = range.c0; i < range.c1; ++i) {
-            cells[i].seq_id.clear();
+    for (const auto & ubatch : pending.ubatches) {
+        for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
+            for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
 
-            // keep count of the number of used cells
-            if (cells[i].pos >= 0) {
-                used--;
-            }
+                cells[ubatch.head + i].seq_id.erase(seq_id);
+                if (cells[ubatch.head + i].seq_id.empty()) {
+                    used--;
 
-            cells[i].pos = -1;
-        }
+                    new_head = std::min(new_head, ubatch.head + i);
+                }
 
-        new_head = std::min(new_head, range.c0);
+                cells[ubatch.head + i].pos = -1;
+            }
+        }
     }
 
     if (new_head != size && new_head < head) {
         head = new_head;
     }
+
+    for (const auto & seq_rm : pending.seq_rms) {
+        GGML_ASSERT(seq_rm.seq_id >= 0 && "seq_rm.seq_id < 0 during restore - should not happen");
+
+        if (cells[seq_rm.c].seq_id.empty()) {
+            GGML_ASSERT(cells[seq_rm.c].pos == -1 && "cells[seq_rm.c].pos != -1 during restore - should not happen");
+            used++;
+        } else {
+            GGML_ASSERT(cells[seq_rm.c].pos == seq_rm.p && "cells[seq_rm.c].pos != seq_rm.p during restore - should not happen");
+        }
+
+        cells[seq_rm.c].seq_id.insert(seq_rm.seq_id);
+        cells[seq_rm.c].pos = seq_rm.p;
+    }
+
+    pending.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
             __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
         return;
     }
 
-    pending.ranges.clear();
+    pending.clear();
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
@@ -430,6 +447,8 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         do_defrag = false;
     }
 
+    pending.clear();
+
     return need_reserve;
 }
 
@@ -459,7 +478,7 @@ llama_sbatch llama_kv_cache_unified::sbatch_init(
 llama_ubatch llama_kv_cache_unified::ubatch_next(
         llama_sbatch & sbatch,
         uint32_t n_ubatch,
-        bool embd_pooled) const {
+        bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
     return sbatch.split_simple(n_ubatch);
 }
@@ -519,7 +538,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
     used += n_tokens;
 
-    pending.ranges.push_back({head, head + n_tokens});
+    pending.ubatches.push_back({ head, ubatch });
 
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
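find_slot() now records the whole ubatch together with the head it was placed at, and restore()/commit()/update() call pending.clear(). The header-side bookkeeping is not part of this diff; a rough sketch of what it presumably looks like (names inferred from the usages above, e.g. ubatch.head and ubatch.data in restore()) is:

// assumed pending bookkeeping (declarations not shown in this diff)
struct slot_ubatch {
    uint32_t     head; // first cell the ubatch was written to
    llama_ubatch data; // copy of the ubatch, so restore() knows which seq_ids/positions to undo
};

struct {
    std::vector<slot_seq_rm> seq_rms;  // cells modified by seq_rm() since the last commit
    std::vector<slot_ubatch> ubatches; // ubatches placed by find_slot() since the last commit

    void clear() {
        seq_rms.clear();
        ubatches.clear();
    }
} pending;

This is only a sketch of the assumed layout, not the actual header change.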
@@ -1568,13 +1587,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool offload,
         uint32_t kv_size,
         uint32_t n_seq_max,
-        uint32_t n_batch,
+        uint32_t n_ubatch,
         uint32_t padding) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t kv_size_base = kv_size;
-    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, padding));
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
 
@@ -1629,21 +1648,6 @@ void llama_kv_cache_unified_iswa::restore() {
 }
 
 void llama_kv_cache_unified_iswa::commit() {
-    if (pending.pos_max.empty()) {
-        return;
-    }
-
-    // slide the window, forgetting old tokens
-    for (const auto & [seq_id, pos_max] : pending.pos_max) {
-        if (pos_max <= (llama_pos) hparams.n_swa) {
-            continue;
-        }
-
-        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
-    }
-
-    pending.pos_max.clear();
-
     kv_base->commit();
     kv_swa ->commit();
 }
@@ -1668,21 +1672,34 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
-    // this will be used upon successful decode, during commit, to remove old SWA tokens
-    for (int i = 0; i < batch.n_tokens; ++i) {
-        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[i][s];
-            const llama_pos    pos    = batch.pos[i];
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+}
 
-            pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
+    GGML_UNUSED(embd_pooled);
+    auto res = sbatch.split_simple(n_ubatch);
+
+    for (uint32_t i = 0; i < res.n_tokens; ++i) {
+        for (int s = 0; s < res.n_seq_id[i]; ++s) {
+            const llama_seq_id seq_id = res.seq_id[i][s];
+            const llama_pos    pos    = res.pos[i];
+
+            pos_max_per_seq[seq_id] = std::max(pos_max_per_seq[seq_id], pos);
         }
     }
 
-    return kv_base->sbatch_init(batch, logits_all);
-}
+    // slide the window, forgetting old tokens
+    for (const auto & [seq_id, pos_max] : pos_max_per_seq) {
+        if (pos_max <= (llama_pos) hparams.n_swa) {
+            continue;
+        }
+
+        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
+    }
+
+    pos_max_per_seq.clear();
 
-llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
+    return res;
 }
 
 bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
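The SWA pruning that used to happen in commit() now runs as soon as ubatch_next() splits off a ubatch: the highest position seen per sequence is accumulated in pos_max_per_seq and anything older than n_swa is dropped from kv_swa right away, which is why the commit() override above shrank to just forwarding to the two child caches. pos_max_per_seq is a new member whose declaration is not shown here; a plausible shape (an assumption, not the actual header change) is:

// assumed new member of llama_kv_cache_unified_iswa (declaration not in this diff),
// requires <map>: highest position seen per sequence, used to slide the SWA window
std::map<llama_seq_id, llama_pos> pos_max_per_seq;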
@@ -2094,7 +2111,7 @@ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
     return llama_sbatch(batch, hparams.n_embd, false, logits_all);
 }
 
-llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
     if (embd_pooled) {
         // Pooled embeddings cannot be split across ubatches (yet)
         return sbatch.split_seq(n_ubatch);