Skip to content

Commit 847e9c8

Browse files
committed
kv-cache : rework error recovery logic
ggml-ci
1 parent d691ff8 commit 847e9c8

File tree

2 files changed

+41
-36
lines changed

2 files changed

+41
-36
lines changed

src/llama-kv-cache.cpp

Lines changed: 25 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -331,43 +331,44 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
331331
}
332332

333333
void llama_kv_cache_unified::restore() {
334-
if (pending.ranges.empty()) {
334+
if (pending.ubatches.empty()) {
335335
return;
336336
}
337337

338-
// TODO: here we assume that all sequences should be removed from the cache which is not always the case
339-
// need to start keeping more detailed pending information per-sequence
340-
341338
uint32_t new_head = size;
342339

343-
for (auto & range : pending.ranges) {
344-
for (uint32_t i = range.c0; i < range.c1; ++i) {
345-
cells[i].seq_id.clear();
340+
for (const auto & ubatch : pending.ubatches) {
341+
for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
342+
for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
343+
const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
346344

347-
// keep count of the number of used cells
348-
if (cells[i].pos >= 0) {
349-
used--;
350-
}
345+
cells[ubatch.head + i].seq_id.erase(seq_id);
346+
if (cells[ubatch.head + i].seq_id.empty()) {
347+
used--;
351348

352-
cells[i].pos = -1;
353-
}
349+
new_head = std::min(new_head, ubatch.head + i);
350+
}
354351

355-
new_head = std::min(new_head, range.c0);
352+
cells[ubatch.head + i].pos = -1;
353+
}
354+
}
356355
}
357356

358357
if (new_head != size && new_head < head) {
359358
head = new_head;
360359
}
360+
361+
pending.clear();
361362
}
362363

363364
void llama_kv_cache_unified::commit() {
364-
if (pending.ranges.empty()) {
365+
if (pending.ubatches.empty()) {
365366
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
366367
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
367368
return;
368369
}
369370

370-
pending.ranges.clear();
371+
pending.clear();
371372
}
372373

373374
bool llama_kv_cache_unified::update(llama_context & lctx) {
@@ -526,7 +527,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
526527

527528
used += n_tokens;
528529

529-
pending.ranges.push_back({head, head + n_tokens});
530+
pending.ubatches.push_back({ head, ubatch });
530531

531532
// a heuristic, to avoid attending the full cache if it is not yet utilized
532533
// after enough generations, the benefit from this heuristic disappears
@@ -1636,11 +1637,14 @@ void llama_kv_cache_unified_iswa::restore() {
16361637
}
16371638

16381639
void llama_kv_cache_unified_iswa::commit() {
1640+
kv_base->commit();
1641+
kv_swa ->commit();
1642+
16391643
if (pending.pos_max.empty()) {
16401644
return;
16411645
}
16421646

1643-
// slide the window, forgetting old tokens
1647+
// slide the attention window, forgetting/pruning old tokens that are outside the window
16441648
for (const auto & [seq_id, pos_max] : pending.pos_max) {
16451649
if (pos_max <= (llama_pos) hparams.n_swa) {
16461650
continue;
@@ -1650,9 +1654,6 @@ void llama_kv_cache_unified_iswa::commit() {
16501654
}
16511655

16521656
pending.pos_max.clear();
1653-
1654-
kv_base->commit();
1655-
kv_swa ->commit();
16561657
}
16571658

16581659
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
@@ -1675,7 +1676,6 @@ void llama_kv_cache_unified_iswa::set_full() {
16751676
}
16761677

16771678
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1678-
// this will be used upon successful decode, during commit, to remove old SWA tokens
16791679
for (int i = 0; i < batch.n_tokens; ++i) {
16801680
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
16811681
const llama_seq_id seq_id = batch.seq_id[i][s];
@@ -1685,11 +1685,12 @@ llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch,
16851685
}
16861686
}
16871687

1688-
return kv_base->sbatch_init(batch, logits_all);
1688+
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
16891689
}
16901690

16911691
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1692-
return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
1692+
GGML_UNUSED(embd_pooled);
1693+
return sbatch.split_simple(n_ubatch);
16931694
}
16941695

16951696
bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {

src/llama-kv-cache.h

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
#include "llama.h"
44
#include "llama-io.h"
5+
#include "llama-batch.h"
56
#include "llama-graph.h"
67
#include "llama-memory.h"
78

@@ -13,8 +14,6 @@
1314

1415
struct llama_cparams;
1516
struct llama_hparams;
16-
struct llama_ubatch;
17-
struct llama_sbatch;
1817
struct llama_model;
1918
struct llama_context;
2019

@@ -178,16 +177,11 @@ class llama_kv_cache_unified : public llama_kv_cache {
178177
const llama_model & model;
179178
const llama_hparams & hparams;
180179

181-
// commit/restore cache
182-
struct slot_range {
183-
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
184-
uint32_t c1 = 0;
185-
};
186-
187180
struct kv_cell {
188181
llama_pos pos = -1;
189182
llama_pos delta = 0;
190183

184+
// TODO: replace with bitset uint64_t
191185
std::set<llama_seq_id> seq_id;
192186

193187
bool has_seq_id(const llama_seq_id & id) const {
@@ -238,10 +232,20 @@ class llama_kv_cache_unified : public llama_kv_cache {
238232
// model layer id -> KV cache layer id
239233
std::map<int32_t, int32_t> map_layer_ids;
240234

235+
struct ubatch_info {
236+
uint32_t head;
237+
238+
llama_ubatch data;
239+
};
240+
241241
// pending cell updates that are not yet committed
242-
// TODO: improve by keeping information per-sequence
243242
struct {
244-
std::vector<slot_range> ranges;
243+
void clear() {
244+
ubatches.clear();
245+
}
246+
247+
// upon batch processing failure, we revert these ubatches from the KV cells
248+
std::vector<ubatch_info> ubatches;
245249
} pending;
246250

247251
// defrag
@@ -362,13 +366,13 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
362366
llama_kv_cache_unified * get_kv_swa () const;
363367

364368
private:
369+
const llama_hparams & hparams;
370+
365371
// pending cell updates that are not yet committed
366372
struct {
367373
std::map<llama_seq_id, llama_pos> pos_max;
368374
} pending;
369375

370-
const llama_hparams & hparams;
371-
372376
std::unique_ptr<llama_kv_cache_unified> kv_base;
373377
std::unique_ptr<llama_kv_cache_unified> kv_swa;
374378
};

0 commit comments

Comments
 (0)