@@ -331,43 +331,44 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         return;
     }
 
-    // TODO: here we assume that all sequences should be removed from the cache which is not always the case
-    // need to start keeping more detailed pending information per-sequence
-
     uint32_t new_head = size;
 
-    for (auto & range : pending.ranges) {
-        for (uint32_t i = range.c0; i < range.c1; ++i) {
-            cells[i].seq_id.clear();
+    for (const auto & ubatch : pending.ubatches) {
+        for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
+            for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
 
-            // keep count of the number of used cells
-            if (cells[i].pos >= 0) {
-                used--;
-            }
+                cells[ubatch.head + i].seq_id.erase(seq_id);
+                if (cells[ubatch.head + i].seq_id.empty()) {
+                    used--;
 
-            cells[i].pos = -1;
-        }
+                    new_head = std::min(new_head, ubatch.head + i);
+                }
 
-        new_head = std::min(new_head, range.c0);
+                cells[ubatch.head + i].pos = -1;
+            }
+        }
     }
 
     if (new_head != size && new_head < head) {
         head = new_head;
     }
+
+    pending.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (pending.ranges.empty()) {
+    if (pending.ubatches.empty()) {
         LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
             __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
         return;
     }
 
-    pending.ranges.clear();
+    pending.clear();
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
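Note: the declaration of `pending` is not part of this diff. Judging from the usage above (`pending.ubatches`, `pending.clear()`) and the `{ head, ubatch }` initializer pushed in `find_slot` below, the bookkeeping presumably moved from a list of cell ranges to one record per ubatch, roughly along these lines (an inferred sketch, not the actual header):

// inferred sketch of the new pending bookkeeping; the field names follow the
// usage in restore() and find_slot(), the struct name itself is a placeholder
struct ubatch_info {
    uint32_t     head; // first cache cell this ubatch was written to
    llama_ubatch data; // per-token seq_id info needed to undo the write
};

struct {
    std::vector<ubatch_info> ubatches;

    void clear() {
        ubatches.clear();
    }
} pending;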
@@ -526,7 +527,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
     used += n_tokens;
 
-    pending.ranges.push_back({head, head + n_tokens});
+    pending.ubatches.push_back({ head, ubatch });
 
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
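Recording `{ head, ubatch }` instead of a bare `{head, head + n_tokens}` range is what lets `restore()` erase only the seq_ids a failed batch actually wrote, rather than wiping whole cells. The intended lifecycle is presumably along these lines (illustrative sketch; `kv` and `run_decode` are placeholders, only `find_slot`/`commit`/`restore` appear in this diff):

// pending entries live between find_slot() and commit()/restore()
if (kv.find_slot(ubatch)) {
    if (run_decode(ubatch)) {  // hypothetical compute step
        kv.commit();           // success: drop the pending records, keep the cells
    } else {
        kv.restore();          // failure: undo exactly what this ubatch wrote
    }
}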
@@ -1636,11 +1637,14 @@ void llama_kv_cache_unified_iswa::restore() {
 }
 
 void llama_kv_cache_unified_iswa::commit() {
+    kv_base->commit();
+    kv_swa ->commit();
+
     if (pending.pos_max.empty()) {
         return;
     }
 
-    // slide the window, forgetting old tokens
+    // slide the attention window, forgetting/pruning old tokens that are outside the window
     for (const auto & [seq_id, pos_max] : pending.pos_max) {
         if (pos_max <= (llama_pos) hparams.n_swa) {
             continue;
@@ -1650,9 +1654,6 @@ void llama_kv_cache_unified_iswa::commit() {
     }
 
     pending.pos_max.clear();
-
-    kv_base->commit();
-    kv_swa ->commit();
 }
 
 bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
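For readability, the resulting `llama_kv_cache_unified_iswa::commit()` after both hunks is assembled below; committing the two child caches first clears their pending ubatch records before the SWA pruning mutates cells. The body of the pruning loop is elided in this diff and stays elided here:

void llama_kv_cache_unified_iswa::commit() {
    kv_base->commit();
    kv_swa ->commit();

    if (pending.pos_max.empty()) {
        return;
    }

    // slide the attention window, forgetting/pruning old tokens that are outside the window
    for (const auto & [seq_id, pos_max] : pending.pos_max) {
        if (pos_max <= (llama_pos) hparams.n_swa) {
            continue;
        }

        // ... prune cells of seq_id that fell outside the window (elided in this diff) ...
    }

    pending.pos_max.clear();
}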
@@ -1675,7 +1676,6 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
-    // this will be used upon successful decode, during commit, to remove old SWA tokens
     for (int i = 0; i < batch.n_tokens; ++i) {
         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
             const llama_seq_id seq_id = batch.seq_id[i][s];
@@ -1685,11 +1685,12 @@ llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch,
         }
     }
 
-    return kv_base->sbatch_init(batch, logits_all);
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
 }
 
 llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
+    GGML_UNUSED(embd_pooled);
+    return sbatch.split_simple(n_ubatch);
 }
 
 bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
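With the wrapper now constructing the `llama_sbatch` and splitting it itself via `split_simple`, both child caches are guaranteed to see identical ubatches instead of relying on `kv_base`'s split. One plausible shape for the wrapper's `find_slot` (its body is not shown in this diff, only the signature above) would then offer the same ubatch to both children, so each records its own `{ head, ubatch }` pending entry:

// speculative sketch: only the signature appears in this diff
bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
    bool res = true;

    // bitwise & rather than && so both caches attempt the slot even if one fails
    res = res & kv_base->find_slot(batch);
    res = res & kv_swa ->find_slot(batch);

    return res;
}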