@@ -331,43 +331,44 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         return;
     }
 
-    // TODO: here we assume that all sequences should be removed from the cache which is not always the case
-    // need to start keeping more detailed pending information per-sequence
-
     uint32_t new_head = size;
 
-    for (auto & range : pending.ranges) {
-        for (uint32_t i = range.c0; i < range.c1; ++i) {
-            cells[i].seq_id.clear();
+    for (const auto & ubatch : pending.ubatches) {
+        for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
+            for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
 
-            // keep count of the number of used cells
-            if (cells[i].pos >= 0) {
-                used--;
-            }
+                cells[ubatch.head + i].seq_id.erase(seq_id);
+                if (cells[ubatch.head + i].seq_id.empty()) {
+                    used--;
 
-            cells[i].pos = -1;
-        }
+                    new_head = std::min(new_head, ubatch.head + i);
+                }
 
-        new_head = std::min(new_head, range.c0);
+                cells[ubatch.head + i].pos = -1;
+            }
+        }
     }
 
     if (new_head != size && new_head < head) {
         head = new_head;
     }
+
+    pending.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
                 __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
         return;
     }
 
-    pending.ranges.clear();
+    pending.clear();
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
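Note: the `pending` container used above is declared in `llama-kv-cache.h`, which this diff does not show. Judging only from its usage here (`pending.ubatches`, `ubatch.head`, `ubatch.data`, `pending.clear()`), it plausibly looks like the sketch below; the entry type name `ubatch_info` is invented for illustration:

```cpp
// Hypothetical sketch of the pending-state bookkeeping (the real definition may
// differ). Each entry pairs the cache offset at which an ubatch was written with
// a copy of the ubatch itself, so restore() can undo exactly the seq_ids that
// ubatch added, instead of wiping whole cell ranges as the old code did.
struct ubatch_info {
    uint32_t     head; // first cell index occupied by this ubatch
    llama_ubatch data; // the ubatch as it was placed into the cache
};

struct {
    std::vector<ubatch_info> ubatches;

    void clear() {
        ubatches.clear();
    }
} pending;
```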
@@ -430,6 +431,8 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         do_defrag = false;
     }
 
+    pending.clear();
+
     return need_reserve;
 }
 
@@ -459,7 +462,7 @@ llama_sbatch llama_kv_cache_unified::sbatch_init(
 llama_ubatch llama_kv_cache_unified::ubatch_next(
         llama_sbatch & sbatch,
         uint32_t n_ubatch,
-        bool embd_pooled) const {
+        bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
     return sbatch.split_simple(n_ubatch);
 }
@@ -519,7 +522,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
     used += n_tokens;
 
-    pending.ranges.push_back({head, head + n_tokens});
+    pending.ubatches.push_back({ head, ubatch });
 
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
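For context, here is a minimal sketch of the lifecycle these hooks are designed for. The caller `decode_all` and its loop are invented for illustration; `find_slot()`, `restore()`, and `commit()` are the functions changed in this diff, and the sketch leans on the project's own types:

```cpp
// Invented caller, for illustration only: each successful find_slot() appends
// an entry to pending.ubatches; on failure we roll back everything placed so
// far, and on success we commit, which simply discards the rollback info.
bool decode_all(llama_kv_cache_unified & kv, const std::vector<llama_ubatch> & ubatches) {
    for (const auto & ub : ubatches) {
        if (!kv.find_slot(ub)) {
            kv.restore(); // erase only the seq_ids recorded in pending.ubatches
            return false;
        }
        // ... build and run the compute graph for `ub` here ...
    }

    kv.commit(); // success: clear pending, keeping the written cells
    return true;
}
```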
@@ -1568,13 +1571,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                      bool   offload,
                  uint32_t   kv_size,
                  uint32_t   n_seq_max,
-                 uint32_t   n_batch,
+                 uint32_t   n_ubatch,
                  uint32_t   padding) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t kv_size_base = kv_size;
-    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch,  padding));
+    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, padding));
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
 
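To make the sizing change concrete, a worked example with invented numbers (`GGML_PAD(x, n)` rounds `x` up to a multiple of `n`):

```cpp
// Invented values, for illustration only:
//   hparams.n_swa = 4096, n_seq_max = 2, n_ubatch = 512, padding = 256
//
//   n_swa*n_seq_max + n_ubatch = 4096*2 + 512 = 8704
//   GGML_PAD(8704, 256)        = 8704   (already a multiple of 256)
//   kv_size_swa                = std::min(kv_size, 8704)
//
// With the old n_batch term (e.g. n_batch = 2048) the same formula gave
// GGML_PAD(4096*2 + 2048, 256) = 10240 cells. The ubatch-based bound is
// tighter because at most one ubatch is written between window slides.
```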
@@ -1629,21 +1632,6 @@ void llama_kv_cache_unified_iswa::restore() {
 }
 
 void llama_kv_cache_unified_iswa::commit() {
-    if (pending.pos_max.empty()) {
-        return;
-    }
-
-    // slide the window, forgetting old tokens
-    for (const auto & [seq_id, pos_max] : pending.pos_max) {
-        if (pos_max <= (llama_pos) hparams.n_swa) {
-            continue;
-        }
-
-        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
-    }
-
-    pending.pos_max.clear();
-
     kv_base->commit();
     kv_swa ->commit();
 }
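With the sliding-window pruning moved out of commit() (it now happens in ubatch_next(), below), the two-level commit simply forwards to both child caches. The members involved, as inferred from usage in this diff; the exact types are an assumption and the real declarations may differ:

```cpp
// Assumed member layout of llama_kv_cache_unified_iswa, inferred from usage:
std::unique_ptr<llama_kv_cache_unified> kv_base;         // full-context cache (non-SWA layers)
std::unique_ptr<llama_kv_cache_unified> kv_swa;          // sliding-window cache (SWA layers)
std::map<llama_seq_id, llama_pos>       pos_max_per_seq; // newest position seen per sequence
```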
@@ -1668,21 +1656,34 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
-    // this will be used upon successful decode, during commit, to remove old SWA tokens
-    for (int i = 0; i < batch.n_tokens; ++i) {
-        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[i][s];
-            const llama_pos    pos    = batch.pos[i];
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+}
+
+llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
+    GGML_UNUSED(embd_pooled);
+    auto res = sbatch.split_simple(n_ubatch);
+
+    for (uint32_t i = 0; i < res.n_tokens; ++i) {
+        for (int s = 0; s < res.n_seq_id[i]; ++s) {
+            const llama_seq_id seq_id = res.seq_id[i][s];
+            const llama_pos    pos    = res.pos[i];
 
-            pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+            pos_max_per_seq[seq_id] = std::max(pos_max_per_seq[seq_id], pos);
         }
     }
 
-    return kv_base->sbatch_init(batch, logits_all);
-}
+    // slide the window, forgetting old tokens
+    for (const auto & [seq_id, pos_max] : pos_max_per_seq) {
+        if (pos_max <= (llama_pos) hparams.n_swa) {
+            continue;
+        }
 
-llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
+        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
+    }
+
+    pos_max_per_seq.clear();
+
+    return res;
 }
 
 bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
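A quick walk-through of the pruning boundary with invented numbers; recall that `seq_rm(seq_id, p0, p1)` removes cells of a sequence whose positions fall in `[p0, p1)`, with `p0 == -1` meaning from the start:

```cpp
// Invented numbers, for illustration only:
const llama_pos pos_max = 20; // newest position seen for this sequence
const uint32_t  n_swa   = 8;  // sliding-window width

// p1 = pos_max - n_swa + 1 = 20 - 8 + 1 = 13, so
// seq_rm(seq_id, -1, 13) drops cells with pos 0..12 and keeps 13..20,
// exactly the n_swa most recent positions, i.e. the attention window.
```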
@@ -2094,7 +2095,7 @@ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
     return llama_sbatch(batch, hparams.n_embd, false, logits_all);
 }
 
-llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
     if (embd_pooled) {
         // Pooled embeddings cannot be split across ubatches (yet)
         return sbatch.split_seq(n_ubatch);