@@ -331,43 +331,44 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         return;
     }
 
-    // TODO: here we assume that all sequences should be removed from the cache which is not always the case
-    // need to start keeping more detailed pending information per-sequence
-
     uint32_t new_head = size;
 
-    for (auto & range : pending.ranges) {
-        for (uint32_t i = range.c0; i < range.c1; ++i) {
-            cells[i].seq_id.clear();
+    for (const auto & ubatch : pending.ubatches) {
+        for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
+            for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
 
-            // keep count of the number of used cells
-            if (cells[i].pos >= 0) {
-                used--;
-            }
+                cells[ubatch.head + i].seq_id.erase(seq_id);
+                if (cells[ubatch.head + i].seq_id.empty()) {
+                    used--;
 
-            cells[i].pos = -1;
-        }
+                    new_head = std::min(new_head, ubatch.head + i);
+                }
 
-        new_head = std::min(new_head, range.c0);
+                cells[ubatch.head + i].pos = -1;
+            }
+        }
     }
 
     if (new_head != size && new_head < head) {
         head = new_head;
     }
+
+    pending.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
             __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
         return;
     }
 
-    pending.ranges.clear();
+    pending.clear();
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
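Note on the hunk above: `restore()` now rolls back exactly the cells written by each pending micro-batch, erasing only the seq_ids that batch added, instead of wiping whole cell ranges. A minimal sketch of the bookkeeping this implies follows; the `kv_cache_pending`/`entry` names and the copy semantics are assumptions inferred from `pending.ubatches.push_back({ head, ubatch })` in the next hunk, not the PR's actual definitions:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch of the pending-state record implied by the diff.
// llama_ubatch is llama.cpp's internal micro-batch type (n_tokens,
// n_seq_id, seq_id, ...); this sketch assumes it can be stored by value.
struct kv_cache_pending {
    struct entry {
        uint32_t     head; // first cache cell this ubatch was written to
        llama_ubatch data; // the ubatch itself, kept for per-seq rollback
    };

    std::vector<entry> ubatches; // one entry per ubatch since the last commit()

    void clear() {
        ubatches.clear();
    }
};
```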
@@ -519,7 +520,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
     used += n_tokens;
 
-    pending.ranges.push_back({head, head + n_tokens});
+    pending.ubatches.push_back({ head, ubatch });
 
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
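For orientation, here is a hedged sketch of the decode-side contract this enables; `decode_batch` is illustrative, not llama.cpp's actual control flow. Every ubatch placed by `find_slot()` is remembered, so a mid-batch failure can be undone precisely:

```cpp
// Illustrative caller-side flow (not verbatim llama.cpp code).
bool decode_batch(llama_kv_cache_unified & kv, const std::vector<llama_ubatch> & ubatches) {
    for (const auto & ubatch : ubatches) {
        if (!kv.find_slot(ubatch)) { // on success, records { head, ubatch } in pending
            kv.restore();            // erase exactly the cells written so far
            return false;
        }
        // ... build and run the compute graph for this ubatch ...
    }
    kv.commit(); // success: drop the pending records, keep the cells
    return true;
}
```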
@@ -1629,11 +1630,14 @@ void llama_kv_cache_unified_iswa::restore() {
 }
 
 void llama_kv_cache_unified_iswa::commit() {
+    kv_base->commit();
+    kv_swa ->commit();
+
     if (pending.pos_max.empty()) {
         return;
     }
 
-    // slide the window, forgetting old tokens
+    // slide the attention window, forgetting/pruning old tokens that are outside the window
     for (const auto & [seq_id, pos_max] : pending.pos_max) {
         if (pos_max <= (llama_pos) hparams.n_swa) {
             continue;
@@ -1643,9 +1647,6 @@ void llama_kv_cache_unified_iswa::commit() {
     }
 
     pending.pos_max.clear();
-
-    kv_base->commit();
-    kv_swa ->commit();
 }
 
 bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
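The reordering in the two hunks above is deliberate: both child caches commit first (clearing their `pending.ubatches`), so the SWA pruning that follows mutates committed state that a later `restore()` can no longer roll back. The pruning call itself is elided between the hunks; the sketch below only illustrates the window arithmetic implied by the `pos_max <= n_swa` guard, and the exact cutoff in the real code may differ by one:

```cpp
// Assumed shape of the elided pruning step (boundary arithmetic is a guess).
for (const auto & [seq_id, pos_max] : pending.pos_max) {
    if (pos_max <= (llama_pos) hparams.n_swa) {
        continue; // the whole sequence still fits inside the window
    }
    // keep only the last n_swa positions of this sequence in the SWA cache
    kv_swa->seq_rm(seq_id, 0, pos_max - (llama_pos) hparams.n_swa);
}
```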
@@ -1668,7 +1669,6 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
-    // this will be used upon successful decode, during commit, to remove old SWA tokens
    for (int i = 0; i < batch.n_tokens; ++i) {
         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
             const llama_seq_id seq_id = batch.seq_id[i][s];
@@ -1678,11 +1678,12 @@ llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch,
         }
     }
 
-    return kv_base->sbatch_init(batch, logits_all);
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
 }
 
 llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
+    GGML_UNUSED(embd_pooled);
+    return sbatch.split_simple(n_ubatch);
 }
 
 bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
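One more note on the last two hunks: batch splitting no longer routes through `kv_base`; the wrapper builds its own `llama_sbatch` (the `true` argument appears to select simple, contiguous splitting) and splits it with `split_simple`. A hedged sketch of how a caller might drive the two entry points; the loop condition and variable names are illustrative:

```cpp
// Illustrative driver loop (not the actual llama_context decode path).
llama_sbatch sbatch = kv_iswa.sbatch_init(batch, /*logits_all=*/false);

while (sbatch.n_tokens > 0) {
    // contiguous split; embd_pooled is ignored by this implementation
    llama_ubatch ubatch = kv_iswa.ubatch_next(sbatch, n_ubatch, /*embd_pooled=*/false);

    if (!kv_iswa.find_slot(ubatch)) {
        kv_iswa.restore(); // roll back anything placed so far
        break;
    }
    // ... compute for this ubatch, then kv_iswa.commit() on success ...
}
```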