@@ -333,31 +333,41 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ubatches.empty()) {
-        return;
-    }
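+    // the rollback strategy depends on the SWA type: without SWA the pending
+    // ubatches can be undone in place; with SWA, pruned cells cannot be
+    // recreated, so we fall back to the snapshot taken in sbatch_init()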
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
+                uint32_t new_head = size;
 
-    uint32_t new_head = size;
+                for (const auto & ubatch : pending.ubatches) {
+                    for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
+                        for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
+                            const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
 
-    for (const auto & ubatch : pending.ubatches) {
-        for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
-            for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
+                            cells[ubatch.head + i].seq_id.erase(seq_id);
+                            if (cells[ubatch.head + i].seq_id.empty()) {
+                                used--;
 
-                cells[ubatch.head + i].seq_id.erase(seq_id);
-                if (cells[ubatch.head + i].seq_id.empty()) {
-                    used--;
+                                new_head = std::min(new_head, ubatch.head + i);
+                            }
 
-                    new_head = std::min(new_head, ubatch.head + i);
+                            cells[ubatch.head + i].pos = -1;
+                        }
+                    }
                 }
 
-                cells[ubatch.head + i].pos = -1;
-            }
-        }
-    }
+                if (new_head != size && new_head < head) {
+                    head = new_head;
+                }
 
-    if (new_head != size && new_head < head) {
-        head = new_head;
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
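+                // prune_swa() removes cells destructively, so roll back to the
+                // copy of the cells saved in sbatch_init()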
+                if (!pending.cells_org.empty()) {
+                    cells = std::move(pending.cells_org);
+                    used  = pending.used_org;
+                }
+            } break;
     }
 
     pending.clear();
@@ -460,16 +470,23 @@ void llama_kv_cache_unified::set_full() {
     head = 0;
 }
 
-llama_sbatch llama_kv_cache_unified::sbatch_init(
-        const llama_batch & batch,
-        bool logits_all) {
+llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
+    switch (swa_type) {
+        case LLAMA_SWA_TYPE_NONE:
+            {
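+                // nothing to do: restore() can undo the pending ubatches in place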
+            } break;
+        case LLAMA_SWA_TYPE_STANDARD:
+        case LLAMA_SWA_TYPE_CHUNKED:
+            {
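+                // save the current cell state so that restore() can roll back
+                // the destructive pruning performed by prune_swa()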
+                pending.cells_org = cells;
+                pending.used_org  = used;
+            } break;
+    }
+
     return llama_sbatch(batch, hparams.n_embd, true, logits_all);
 }
 
-llama_ubatch llama_kv_cache_unified::ubatch_next(
-        llama_sbatch & sbatch,
-        uint32_t n_ubatch,
-        bool embd_pooled) const {
+llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
     return sbatch.split_simple(n_ubatch);
 }
@@ -642,6 +659,33 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
+    GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE);
+
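+    // erase every cell whose position is masked by the sliding window relative
+    // to the newest position p1; seq_id < 0 applies the prune to all sequences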
+    for (uint32_t i = 0; i < size; ++i) {
+        const llama_pos p0 = cells[i].pos;
+
+        if (is_masked_swa(p0, p1)) {
+            if (seq_id < 0) {
+                cells[i].seq_id.clear();
+            } else if (cells[i].has_seq_id(seq_id)) {
+                cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+
+            if (cells[i].is_empty()) {
+                // keep count of the number of used cells
+                if (cells[i].pos >= 0) {
+                    used--;
+                }
+
+                cells[i].pos = -1;
+            }
+        }
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens     = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -1589,13 +1633,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         bool offload,
         uint32_t kv_size,
         uint32_t n_seq_max,
-        uint32_t n_batch,
+        uint32_t n_ubatch,
         uint32_t padding) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t kv_size_base = kv_size;
-    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
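+    // the SWA cache is now pruned after every ubatch, so per sequence it only
+    // needs to hold the attention window plus one ubatch worth of new tokens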
+    const uint32_t kv_size_swa  = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, padding));
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
 
@@ -1658,21 +1702,6 @@ void llama_kv_cache_unified_iswa::restore() {
 void llama_kv_cache_unified_iswa::commit() {
     kv_base->commit();
     kv_swa ->commit();
-
-    if (pending.pos_max.empty()) {
-        return;
-    }
-
-    // slide the attention window, forgetting/pruning old tokens that are outside the window
-    for (const auto & [seq_id, pos_max] : pending.pos_max) {
-        if (pos_max <= (llama_pos) hparams.n_swa) {
-            continue;
-        }
-
-        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
-    }
-
-    pending.pos_max.clear();
 }
 
 bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
@@ -1695,21 +1724,30 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
-    for (int i = 0; i < batch.n_tokens; ++i) {
-        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[i][s];
-            const llama_pos    pos    = batch.pos[i];
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+}
+
+llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
+    GGML_UNUSED(embd_pooled);
+    auto res = sbatch.split_simple(n_ubatch);
 
-            pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+    pos_max_per_seq.clear();
+
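+    // record the newest position seen for each sequence in this ubatch; it
+    // marks how far the sequence's attention window has advanced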
+    for (uint32_t i = 0; i < res.n_tokens; ++i) {
+        for (int s = 0; s < res.n_seq_id[i]; ++s) {
+            const llama_seq_id seq_id = res.seq_id[i][s];
+            const llama_pos    pos    = res.pos[i];
+
+            pos_max_per_seq[seq_id] = std::max(pos_max_per_seq[seq_id], pos);
         }
     }
 
-    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
-}
+    // slide the attention window, forgetting/pruning old tokens that are outside the window
+    for (const auto & [seq_id, pos_max] : pos_max_per_seq) {
+        kv_swa->prune_swa(seq_id, pos_max);
+    }
 
-llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
-    GGML_UNUSED(embd_pooled);
-    return sbatch.split_simple(n_ubatch);
+    return res;
 }
 
 bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
@@ -2122,7 +2160,7 @@ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
     return llama_sbatch(batch, hparams.n_embd, false, logits_all);
 }
 
-llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
     if (embd_pooled) {
         // Pooled embeddings cannot be split across ubatches (yet)
         return sbatch.split_seq(n_ubatch);