@@ -331,43 +331,40 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         return;
     }
 
-    // TODO: here we assume that all sequences should be removed from the cache which is not always the case
-    //       need to start keeping more detailed pending information per-sequence
-
     uint32_t new_head = size;
 
-    for (auto & range : pending.ranges) {
-        for (uint32_t i = range.c0; i < range.c1; ++i) {
-            cells[i].seq_id.clear();
+    for (const auto & ubatch : pending.ubatches) {
+        for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
+            for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
 
-            // keep count of the number of used cells
-            if (cells[i].pos >= 0) {
-                used--;
-            }
+                cells[ubatch.head + i].seq_id.erase(seq_id);
+                if (cells[ubatch.head + i].seq_id.empty()) {
+                    used--;
 
-            cells[i].pos = -1;
-        }
+                    new_head = std::min(new_head, ubatch.head + i);
+                }
 
-        new_head = std::min(new_head, range.c0);
+                cells[ubatch.head + i].pos = -1;
+            }
+        }
     }
 
-    if (new_head != size && new_head < head) {
-        head = new_head;
-    }
+    pending.ubatches.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
             __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
         return;
     }
 
-    pending.ranges.clear();
+    pending.ubatches.clear();
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
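The rewritten restore() walks pending.ubatches and undoes each staged ubatch cell by cell, which is what retires the old TODO about keeping per-sequence pending information. The declaration behind pending.ubatches lives in the header, which this excerpt does not show; below is a minimal sketch of what the diff implies, with ubatch_info as an assumed name inferred from the ubatch.head / ubatch.data accesses:

// Sketch only - names inferred from usage in this diff, not copied from the PR.
// Each pending entry remembers where find_slot() placed a ubatch (head) and
// what it placed there (data), so restore() can erase exactly those seq_ids.
struct ubatch_info {
    uint32_t     head;
    llama_ubatch data;
};

struct {
    std::vector<ubatch_info> ubatches;
} pending;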
@@ -459,7 +456,7 @@ llama_sbatch llama_kv_cache_unified::sbatch_init(
 llama_ubatch llama_kv_cache_unified::ubatch_next(
         llama_sbatch & sbatch,
         uint32_t n_ubatch,
-        bool embd_pooled) const {
+        bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
     return sbatch.split_simple(n_ubatch);
 }
@@ -519,7 +516,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
     used += n_tokens;
 
-    pending.ranges.push_back({head, head + n_tokens});
+    pending.ubatches.push_back({ head, ubatch });
 
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
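For orientation, here is how the staged { head, ubatch } records are expected to pair with commit() and restore() on the decode path. This is a hedged sketch: find_slot(), commit() and restore() are the methods shown in this diff, while decode_ubatch and run_graph are hypothetical stand-ins for the caller:

// Illustrative control flow, not code from this PR: find_slot() stages a
// record, commit() keeps the cells and drops the records, restore() walks
// the records and rolls back exactly the cells this decode staged.
bool decode_ubatch(llama_kv_cache_unified & kv, const llama_ubatch & ubatch) {
    if (!kv.find_slot(ubatch)) {
        return false;  // nothing staged, nothing to undo
    }
    const bool ok = run_graph(ubatch);  // hypothetical compute step
    if (ok) {
        kv.commit();   // make the pending cells permanent
    } else {
        kv.restore();  // undo only what this decode staged
    }
    return ok;
}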
@@ -1568,13 +1565,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool offload,
         uint32_t kv_size,
         uint32_t n_seq_max,
-        uint32_t n_batch,
+        uint32_t n_ubatch,
         uint32_t padding) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t kv_size_base = kv_size;
-    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch,  padding));
+    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, padding));
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
 
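The renamed parameter also clarifies the sizing rule: the SWA cache needs roughly one attention window per sequence plus one micro-batch of headroom, rounded up to the padding granularity. A worked example with hypothetical values (none of these numbers come from this PR):

// Hypothetical values: n_swa = 4096, n_seq_max = 1, n_ubatch = 512, padding = 256.
// GGML_PAD rounds its first argument up to a multiple of the second, so
//   GGML_PAD(4096*1 + 512, 256) = 4608
// and with kv_size = 32768 the SWA cache is capped at min(32768, 4608) cells.
static_assert((4096*1 + 512 + 256 - 1) / 256 * 256 == 4608, "SWA sizing example");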
@@ -1629,21 +1626,6 @@ void llama_kv_cache_unified_iswa::restore() {
 }
 
 void llama_kv_cache_unified_iswa::commit() {
-    if (pending.pos_max.empty()) {
-        return;
-    }
-
-    // slide the window, forgetting old tokens
-    for (const auto & [seq_id, pos_max] : pending.pos_max) {
-        if (pos_max <= (llama_pos) hparams.n_swa) {
-            continue;
-        }
-
-        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
-    }
-
-    pending.pos_max.clear();
-
     kv_base->commit();
     kv_swa ->commit();
 }
@@ -1668,21 +1650,34 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
-    // this will be used upon successful decode, during commit, to remove old SWA tokens
-    for (int i = 0; i < batch.n_tokens; ++i) {
-        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[i][s];
-            const llama_pos    pos    = batch.pos[i];
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+}
+
+llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
+    GGML_UNUSED(embd_pooled);
+    auto res = sbatch.split_simple(n_ubatch);
 
-            pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+    for (uint32_t i = 0; i < res.n_tokens; ++i) {
+        for (int s = 0; s < res.n_seq_id[i]; ++s) {
+            const llama_seq_id seq_id = res.seq_id[i][s];
+            const llama_pos    pos    = res.pos[i];
+
+            pos_max_per_seq[seq_id] = std::max(pos_max_per_seq[seq_id], pos);
         }
     }
 
-    return kv_base->sbatch_init(batch, logits_all);
-}
+    // slide the window, forgetting old tokens
+    for (const auto & [seq_id, pos_max] : pos_max_per_seq) {
+        if (pos_max <= (llama_pos) hparams.n_swa) {
+            continue;
+        }
 
-llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
+        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
+    }
+
+    pos_max_per_seq.clear();
+
+    return res;
 }
 
 bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
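Note that the window pruning which used to live in commit() now runs eagerly in ubatch_next(), immediately after the per-sequence position maxima have been collected into pos_max_per_seq. The cutoff arithmetic keeps exactly the last n_swa positions of each sequence; with hypothetical numbers (not from this PR):

// With n_swa = 4096 and a sequence whose highest scheduled position is
// pos_max = 5000, seq_rm(seq_id, -1, 5000 - 4096 + 1) removes positions
// [0, 905) and keeps 905..5000 inclusive - exactly n_swa positions.
static_assert(5000 - (5000 - 4096 + 1) + 1 == 4096, "SWA cutoff keeps n_swa positions");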
@@ -2094,7 +2089,7 @@ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
     return llama_sbatch(batch, hparams.n_embd, false, logits_all);
 }
 
-llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
     if (embd_pooled) {
         // Pooled embeddings cannot be split across ubatches (yet)
         return sbatch.split_seq(n_ubatch);