Commit ca52e19

kv-cache : keep track of partial SWA computes and print warnings

1 parent 12ee6db

File tree: 2 files changed (+28, -11 lines)

src/llama-kv-cache.cpp

Lines changed: 20 additions & 8 deletions

@@ -668,14 +668,20 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
+void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax) {
     // no pruning is needed when the cache does not use SWA
     GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
 
+    int n_attended = 0;
+
     for (uint32_t i = 0; i < size; ++i) {
         const llama_pos p0 = cells[i].pos;
 
-        if (is_masked_swa(p0, p1)) {
+        if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
+            n_attended++;
+        }
+
+        if (is_masked_swa(p0, pmax)) {
             if (seq_id < 0) {
                 cells[i].seq_id.clear();
             } else if (cells[i].has_seq_id(seq_id)) {
@@ -694,6 +700,10 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
             }
         }
     }
+
+    if (n_attended < std::min<int>(n_swa, pmin)) {
+        LLAMA_LOG_WARN("%s: partial SWA cache detected - possible loss of information, pmin = %d, n_attended = %d, n_swa = %d\n", __func__, pmin, n_attended, n_swa);
+    }
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
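
Note on the new check: the warning boils down to a reachability count. Of the cells still resident in the SWA cache, how many can the earliest token of the committed batch (position pmin) still attend to? If that count falls short of min(n_swa, pmin), part of the window was discarded before it was ever attended, i.e. a partial SWA compute. A minimal standalone sketch of the condition follows; it assumes a plain sliding-window mask (the real is_masked_swa depends on swa_type), and check_partial_swa is an illustrative free function, not the library API.

#include <algorithm>
#include <cstdio>
#include <vector>

using llama_pos = int;

// assumed "standard" SWA predicate: p0 is masked for a query at p1
// when it lies outside the window of size n_swa (or the cell is empty)
static bool is_masked_swa(llama_pos p0, llama_pos p1, int n_swa) {
    return p0 < 0 || p1 - p0 >= n_swa;
}

// cells: positions currently resident in the SWA cache (-1 = empty)
static void check_partial_swa(const std::vector<llama_pos> & cells,
                              llama_pos pmin, int n_swa) {
    int n_attended = 0;
    for (llama_pos p0 : cells) {
        // count cells the token at pmin can still attend to
        if (p0 <= pmin && !is_masked_swa(p0, pmin, n_swa)) {
            n_attended++;
        }
    }
    // a token at pmin should see min(n_swa, pmin) predecessors;
    // fewer means part of its window is already gone
    if (n_attended < std::min<int>(n_swa, pmin)) {
        std::printf("partial SWA cache: pmin = %d, n_attended = %d, n_swa = %d\n",
                    pmin, n_attended, n_swa);
    }
}
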
@@ -1723,8 +1733,8 @@ void llama_kv_cache_unified_iswa::commit() {
     kv_swa ->commit();
 
     // slide the attention window, forgetting/pruning old tokens that are outside the window
-    for (const auto & [seq_id, pos_max] : pending.pos_max) {
-        kv_swa->prune_swa(seq_id, pos_max);
+    for (const auto & [seq_id, entry] : pending.pos) {
+        kv_swa->prune_swa(seq_id, entry.pmin, entry.pmax);
     }
 
     pending.clear();
@@ -1750,17 +1760,19 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
-    pending.pos_max.clear();
+    pending.clear();
 
     for (int i = 0; i < batch.n_tokens; ++i) {
         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
            const llama_seq_id seq_id = batch.seq_id[i][s];
            const llama_pos pos = batch.pos[i];
 
-            if (pending.pos_max.find(seq_id) == pending.pos_max.end()) {
-                pending.pos_max[seq_id] = pos;
+            if (pending.pos.find(seq_id) == pending.pos.end()) {
+                pending.pos[seq_id].pmin = pos;
+                pending.pos[seq_id].pmax = pos;
             } else {
-                pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+                pending.pos[seq_id].pmin = std::min(pending.pos[seq_id].pmin, pos);
+                pending.pos[seq_id].pmax = std::max(pending.pos[seq_id].pmax, pos);
             }
         }
     }
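
The bookkeeping added to sbatch_init reduces to a per-sequence running min/max over the token positions of the batch. A self-contained sketch of the same pattern, with a simplified batch type (parallel pos/seq_id arrays) and a hypothetical track_pending helper standing in for the member function:

#include <algorithm>
#include <unordered_map>
#include <vector>

using llama_pos    = int;
using llama_seq_id = int;

// mirrors the new pending.pos entry type from llama-kv-cache.h
struct entry {
    llama_pos pmin;
    llama_pos pmax;
};

// hypothetical, simplified stand-in for llama_batch
struct batch_view {
    std::vector<llama_pos>    pos;     // position of each token
    std::vector<llama_seq_id> seq_id;  // sequence each token belongs to
};

std::unordered_map<llama_seq_id, entry> track_pending(const batch_view & batch) {
    std::unordered_map<llama_seq_id, entry> pending;
    for (size_t i = 0; i < batch.pos.size(); ++i) {
        auto it = pending.find(batch.seq_id[i]);
        if (it == pending.end()) {
            // first token seen for this sequence: range collapses to [pos, pos]
            pending[batch.seq_id[i]] = { batch.pos[i], batch.pos[i] };
        } else {
            // widen the tracked range to cover this token
            it->second.pmin = std::min(it->second.pmin, batch.pos[i]);
            it->second.pmax = std::max(it->second.pmax, batch.pos[i]);
        }
    }
    return pending;
}

On commit(), each entry then drives one prune_swa(seq_id, entry.pmin, entry.pmax) call, as in the hunk above: pmax decides what gets pruned, pmin decides whether to warn.
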

src/llama-kv-cache.h

Lines changed: 8 additions & 3 deletions

@@ -178,7 +178,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
 
-    void prune_swa(llama_seq_id seq_id, llama_pos p1);
+    void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
 
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift (ggml_tensor * dst) const;
@@ -381,11 +381,16 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
     const llama_hparams & hparams;
 
     struct {
+        struct entry {
+            llama_pos pmin;
+            llama_pos pmax;
+        };
+
         void clear() {
-            pos_max.clear();
+            pos.clear();
         }
 
-        std::unordered_map<llama_seq_id, llama_pos> pos_max;
+        std::unordered_map<llama_seq_id, entry> pos;
     } pending;
 
     std::unique_ptr<llama_kv_cache_unified> kv_base;
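
For a hedged numeric illustration of why pmin is tracked in addition to pmax: with n_swa = 4, if earlier pruning left only positions 10..13 in the SWA cache and a new batch starts at position 20 (positions 14..19 never computed or already pruned), the earliest new token has nothing left in its window and the warning should fire. A toy demo under the same assumed mask as above, with made-up numbers:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_swa = 4;

    // cache contents after earlier pruning: only positions 10..13 survive
    std::vector<int> cells = {10, 11, 12, 13};

    // earliest position of the new batch; 14..19 were skipped
    const int pmin = 20;

    int n_attended = 0;
    for (int p0 : cells) {
        // assumed standard SWA mask: attendable iff within n_swa of pmin
        if (p0 <= pmin && pmin - p0 < n_swa) {
            n_attended++;
        }
    }

    // 0 attended < min(4, 20): mirrors the new LLAMA_LOG_WARN condition
    if (n_attended < std::min(n_swa, pmin)) {
        std::printf("partial SWA cache detected: pmin = %d, n_attended = %d, n_swa = %d\n",
                    pmin, n_attended, n_swa);
    }
    return 0;
}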
