@@ -668,14 +668,20 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
+void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
     // no pruning is needed when the cache does not use SWA
     GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
 
+    int n_attended = 0;
+
     for (uint32_t i = 0; i < size; ++i) {
         const llama_pos p0 = cells[i].pos;
 
-        if (is_masked_swa(p0, p1)) {
+        if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
+            n_attended++;
+        }
+
+        if (is_masked_swa(p0, pmax)) {
             if (seq_id < 0) {
                 cells[i].seq_id.clear();
             } else if (cells[i].has_seq_id(seq_id)) {
@@ -694,6 +700,10 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
             }
         }
     }
+
+    if (n_attended < std::min<int>(n_swa, pmin)) {
+        LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
+    }
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
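
A standalone sketch of the warning condition added above, assuming the standard sliding-window rule where a cached position p0 is masked for a query at position p1 once p1 - p0 >= n_swa (the exact rule lives in is_masked_swa() and depends on swa_type); the cached positions below are made up for illustration only.

// Minimal sketch, not the actual implementation: counts how many cached
// positions the query at pmin can still attend to and warns when part of
// its window has already been pruned, mirroring the check in prune_swa().
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_swa = 4;
    const int pmin  = 6;                      // smallest position in the incoming batch

    // positions still present in the SWA cache for this sequence
    // (earlier positions were pruned in this hypothetical run)
    const std::vector<int> cached_pos = { 4, 5 };

    int n_attended = 0;
    for (const int p0 : cached_pos) {
        // p0 is attended by the query at pmin if it is not masked out
        if (p0 <= pmin && !(pmin - p0 >= n_swa)) {
            n_attended++;
        }
    }

    // the query at pmin should still see roughly min(n_swa, pmin) earlier
    // tokens; fewer than that suggests part of its window is already gone
    if (n_attended < std::min(n_swa, pmin)) {
        printf("partial SWA cache: pmin = %d, n_attended = %d, n_swa = %d\n", pmin, n_attended, n_swa);
    }
    return 0;
}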
@@ -1723,8 +1733,8 @@ void llama_kv_cache_unified_iswa::commit() {
     kv_swa->commit();
 
     // slide the attention window, forgetting/pruning old tokens that are outside the window
-    for (const auto & [seq_id, pos_max] : pending.pos_max) {
-        kv_swa->prune_swa(seq_id, pos_max);
+    for (const auto & [seq_id, entry] : pending.pos) {
+        kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
     }
 
     pending.clear();
@@ -1750,17 +1760,19 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
-    pending.pos_max.clear();
+    pending.clear();
 
     for (int i = 0; i < batch.n_tokens; ++i) {
         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
             const llama_seq_id seq_id = batch.seq_id[i][s];
             const llama_pos    pos    = batch.pos[i];
 
-            if (pending.pos_max.find(seq_id) == pending.pos_max.end()) {
-                pending.pos_max[seq_id] = pos;
+            if (pending.pos.find(seq_id) == pending.pos.end()) {
+                pending.pos[seq_id].pmin = pos;
+                pending.pos[seq_id].pmax = pos;
             } else {
-                pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+                pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
+                pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
             }
         }
     }
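
The hunks above use pending.pos[seq_id].pmin / .pmax and pending.clear(), whose definitions are outside this diff. A minimal sketch of what that pending state could look like, with hypothetical type names (iswa_pending, seq_pos) and typedefs that mirror the int32_t aliases from llama.h:

// Hypothetical reconstruction, not the actual definition: per-sequence
// position range recorded in sbatch_init() and consumed in commit(),
// which forwards it to prune_swa(seq_id, pmin, pmax).
#include <cstdint>
#include <map>

using llama_pos    = std::int32_t; // assumed to match the llama.h typedefs
using llama_seq_id = std::int32_t;

struct iswa_pending {
    struct seq_pos {
        llama_pos pmin; // lowest position of the sequence in the pending batch
        llama_pos pmax; // highest position of the sequence in the pending batch
    };

    std::map<llama_seq_id, seq_pos> pos;

    void clear() {
        pos.clear();
    }
};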