Skip to content

Commit dfc8fdc

Browse files
committed
cont : forget old SWA tokens on successful commit
ggml-ci
1 parent acfbbc7 commit dfc8fdc

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

src/llama-kv-cache.cpp

Lines changed: 32 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,7 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
174174
} else {
175175
continue;
176176
}
177+
177178
if (cells[i].is_empty()) {
178179
// keep count of the number of used cells
179180
if (cells[i].pos >= 0) {
@@ -340,6 +341,9 @@ void llama_kv_cache_unified::restore() {
340341
return;
341342
}
342343

344+
// TODO: here we assume that all sequences should be removed from the cache which is not always the case
345+
// need to start keeping more detailed pending information per-sequence
346+
343347
uint32_t new_head = size;
344348

345349
for (auto & range : pending.ranges) {
@@ -1374,6 +1378,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
13741378
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
13751379
return false;
13761380
}
1381+
13771382
commit();
13781383

13791384
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
@@ -1569,9 +1574,10 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15691574

15701575
// TODO: provide from the llama_context
15711576
const uint32_t n_seq_max = 1;
1577+
const uint32_t n_batch = hparams.n_swa;
15721578

15731579
const uint32_t kv_size_base = kv_size;
1574-
const uint32_t kv_size_swa = hparams.n_swa*n_seq_max;
1580+
const uint32_t kv_size_swa = (hparams.n_swa + n_batch)*n_seq_max;
15751581

15761582
kv_base = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_base), type_k, type_v, v_trans, offload, kv_size_base, padding);
15771583
kv_swa = std::make_unique<llama_kv_cache_unified>(model, std::move(filter_swa), type_k, type_v, v_trans, offload, kv_size_swa, padding);
@@ -1621,6 +1627,21 @@ void llama_kv_cache_unified_iswa::restore() {
16211627
}
16221628

16231629
void llama_kv_cache_unified_iswa::commit() {
1630+
if (pending.pos_max.empty()) {
1631+
return;
1632+
}
1633+
1634+
// slide the window, forgetting old tokens
1635+
for (const auto & [seq_id, pos_max] : pending.pos_max) {
1636+
if (pos_max <= (llama_pos) hparams.n_swa) {
1637+
continue;
1638+
}
1639+
1640+
kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa);
1641+
}
1642+
1643+
pending.pos_max.clear();
1644+
16241645
kv_base->commit();
16251646
kv_swa ->commit();
16261647
}
@@ -1645,6 +1666,16 @@ void llama_kv_cache_unified_iswa::set_full() {
16451666
}
16461667

16471668
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1669+
// this will be used upon successful decode, during commit, to remove old SWA tokens
1670+
for (int i = 0; i < batch.n_tokens; ++i) {
1671+
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
1672+
const llama_seq_id seq_id = batch.seq_id[i][s];
1673+
const llama_pos pos = batch.pos[i];
1674+
1675+
pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
1676+
}
1677+
}
1678+
16481679
return kv_base->sbatch_init(batch, logits_all);
16491680
}
16501681

src/llama-kv-cache.h

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
135135
void set_full() override;
136136

137137
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
138-
139138
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
140139

141140
// updates the cache head
@@ -242,6 +241,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
242241
std::map<int32_t, int32_t> map_layer_ids;
243242

244243
// pending cell updates that are not yet committed
244+
// TODO: improve by keeping information per-sequence
245245
struct {
246246
std::vector<slot_range> ranges;
247247
} pending;
@@ -333,12 +333,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
333333
void set_full() override;
334334

335335
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
336-
337336
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
338337

339-
// updates the cache head
340-
// Note: On success, it's important that cache.head points
341-
// to the first cell of the slot.
342338
bool find_slot(const llama_ubatch & batch) override;
343339

344340
int32_t get_n_tokens() const override;
@@ -362,6 +358,11 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
362358
llama_kv_cache_unified * get_kv_swa () const;
363359

364360
private:
361+
// pending cell updates that are not yet committed
362+
struct {
363+
std::map<llama_seq_id, llama_pos> pos_max;
364+
} pending;
365+
365366
const llama_hparams & hparams;
366367

367368
std::unique_ptr<llama_kv_cache_unified> kv_base;
@@ -431,7 +432,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
431432
void set_full() override;
432433

433434
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
434-
435435
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
436436

437437
bool find_slot(const llama_ubatch & batch) override;

0 commit comments

Comments (0)