Skip to content

Commit 7776135

Browse files
committed
kv-cache : rework error recovery logic
ggml-ci
1 parent 3ad524a commit 7776135

File tree

3 files changed

+62
-64
lines changed

3 files changed

+62
-64
lines changed

src/llama-kv-cache.cpp

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -331,43 +331,40 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
331331
}
332332

333333
void llama_kv_cache_unified::restore() {
334-
if (pending.ranges.empty()) {
334+
if (pending.ubatches.empty()) {
335335
return;
336336
}
337337

338-
// TODO: here we assume that all sequences should be removed from the cache which is not always the case
339-
// need to start keeping more detailed pending information per-sequence
340-
341338
uint32_t new_head = size;
342339

343-
for (auto & range : pending.ranges) {
344-
for (uint32_t i = range.c0; i < range.c1; ++i) {
345-
cells[i].seq_id.clear();
340+
for (const auto & ubatch : pending.ubatches) {
341+
for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
342+
for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
343+
const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
346344

347-
// keep count of the number of used cells
348-
if (cells[i].pos >= 0) {
349-
used--;
350-
}
345+
cells[ubatch.head + i].seq_id.erase(seq_id);
346+
if (cells[ubatch.head + i].seq_id.empty()) {
347+
used--;
351348

352-
cells[i].pos = -1;
353-
}
349+
new_head = std::min(new_head, ubatch.head + i);
350+
}
354351

355-
new_head = std::min(new_head, range.c0);
352+
cells[ubatch.head + i].pos = -1;
353+
}
354+
}
356355
}
357356

358-
if (new_head != size && new_head < head) {
359-
head = new_head;
360-
}
357+
pending.ubatches.clear();
361358
}
362359

363360
void llama_kv_cache_unified::commit() {
364-
if (pending.ranges.empty()) {
361+
if (pending.ubatches.empty()) {
365362
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
366363
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
367364
return;
368365
}
369366

370-
pending.ranges.clear();
367+
pending.ubatches.clear();
371368
}
372369

373370
bool llama_kv_cache_unified::update(llama_context & lctx) {
@@ -459,7 +456,7 @@ llama_sbatch llama_kv_cache_unified::sbatch_init(
459456
llama_ubatch llama_kv_cache_unified::ubatch_next(
460457
llama_sbatch & sbatch,
461458
uint32_t n_ubatch,
462-
bool embd_pooled) const {
459+
bool embd_pooled) {
463460
GGML_UNUSED(embd_pooled);
464461
return sbatch.split_simple(n_ubatch);
465462
}
@@ -519,7 +516,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
519516

520517
used += n_tokens;
521518

522-
pending.ranges.push_back({head, head + n_tokens});
519+
pending.ubatches.push_back({ head, ubatch });
523520

524521
// a heuristic, to avoid attending the full cache if it is not yet utilized
525522
// after enough generations, the benefit from this heuristic disappears
@@ -1568,13 +1565,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15681565
bool offload,
15691566
uint32_t kv_size,
15701567
uint32_t n_seq_max,
1571-
uint32_t n_batch,
1568+
uint32_t n_ubatch,
15721569
uint32_t padding) : hparams(model.hparams) {
15731570
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
15741571
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
15751572

15761573
const uint32_t kv_size_base = kv_size;
1577-
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
1574+
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, padding));
15781575

15791576
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
15801577

@@ -1629,21 +1626,6 @@ void llama_kv_cache_unified_iswa::restore() {
16291626
}
16301627

16311628
void llama_kv_cache_unified_iswa::commit() {
1632-
if (pending.pos_max.empty()) {
1633-
return;
1634-
}
1635-
1636-
// slide the window, forgetting old tokens
1637-
for (const auto & [seq_id, pos_max] : pending.pos_max) {
1638-
if (pos_max <= (llama_pos) hparams.n_swa) {
1639-
continue;
1640-
}
1641-
1642-
kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
1643-
}
1644-
1645-
pending.pos_max.clear();
1646-
16471629
kv_base->commit();
16481630
kv_swa ->commit();
16491631
}
@@ -1668,21 +1650,34 @@ void llama_kv_cache_unified_iswa::set_full() {
16681650
}
16691651

16701652
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1671-
// this will be used upon successful decode, during commit, to remove old SWA tokens
1672-
for (int i = 0; i < batch.n_tokens; ++i) {
1673-
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
1674-
const llama_seq_id seq_id = batch.seq_id[i][s];
1675-
const llama_pos pos = batch.pos[i];
1653+
return kv_base->sbatch_init(batch, logits_all);
1654+
}
1655+
1656+
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
1657+
GGML_UNUSED(embd_pooled);
1658+
auto res = sbatch.split_simple(n_ubatch);
1659+
1660+
for (uint32_t i = 0; i < res.n_tokens; ++i) {
1661+
for (int s = 0; s < res.n_seq_id[i]; ++s) {
1662+
const llama_seq_id seq_id = res.seq_id[i][s];
1663+
const llama_pos pos = res.pos[i];
16761664

1677-
pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
1665+
pos_max_per_seq[seq_id] = std::max(pos_max_per_seq[seq_id], pos);
16781666
}
16791667
}
16801668

1681-
return kv_base->sbatch_init(batch, logits_all);
1682-
}
1669+
// slide the window, forgetting old tokens
1670+
for (const auto & [seq_id, pos_max] : pos_max_per_seq) {
1671+
if (pos_max <= (llama_pos) hparams.n_swa) {
1672+
continue;
1673+
}
16831674

1684-
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1685-
return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
1675+
kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
1676+
}
1677+
1678+
pos_max_per_seq.clear();
1679+
1680+
return res;
16861681
}
16871682

16881683
bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
@@ -2094,7 +2089,7 @@ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
20942089
return llama_sbatch(batch, hparams.n_embd, false, logits_all);
20952090
}
20962091

2097-
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
2092+
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
20982093
if (embd_pooled) {
20992094
// Pooled embeddings cannot be split across ubatches (yet)
21002095
return sbatch.split_seq(n_ubatch);

src/llama-kv-cache.h

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "llama.h"
44
#include "llama-io.h"
5+
#include "llama-batch.h"
56
#include "llama-graph.h"
67
#include "llama-memory.h"
78

@@ -13,8 +14,6 @@
1314

1415
struct llama_cparams;
1516
struct llama_hparams;
16-
struct llama_ubatch;
17-
struct llama_sbatch;
1817
struct llama_model;
1918
struct llama_context;
2019

@@ -44,7 +43,7 @@ struct llama_kv_cache : public llama_memory_i {
4443
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
4544

4645
// different KV caches require different batch splitting strategies
47-
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
46+
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) = 0;
4847

4948
// find an empty slot of size "n_tokens" in the cache
5049
virtual bool find_slot(const llama_ubatch & batch) = 0;
@@ -135,7 +134,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
135134
void set_full() override;
136135

137136
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
138-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
137+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
139138

140139
// updates the cache head
141140
// Note: On success, it's important that cache.head points
@@ -188,6 +187,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
188187
llama_pos pos = -1;
189188
llama_pos delta = 0;
190189

190+
// TODO: replace with bitset uint64_t
191191
std::set<llama_seq_id> seq_id;
192192

193193
bool has_seq_id(const llama_seq_id & id) const {
@@ -241,12 +241,18 @@ class llama_kv_cache_unified : public llama_kv_cache {
241241
// model layer id -> KV cache layer id
242242
std::map<int32_t, int32_t> map_layer_ids;
243243

244+
struct ubatch_info {
245+
uint32_t head;
246+
247+
llama_ubatch data;
248+
};
249+
244250
// pending cell updates that are not yet committed
245-
// TODO: improve by keeping information per-sequence
246251
struct {
247-
std::vector<slot_range> ranges;
252+
std::vector<ubatch_info> ubatches;
248253
} pending;
249254

255+
250256
// defrag
251257
struct {
252258
std::vector<uint32_t> ids;
@@ -307,7 +313,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
307313
bool offload,
308314
uint32_t kv_size,
309315
uint32_t n_seq_max,
310-
uint32_t n_batch,
316+
uint32_t n_ubatch,
311317
uint32_t padding);
312318

313319
~llama_kv_cache_unified_iswa() = default;
@@ -340,7 +346,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
340346
void set_full() override;
341347

342348
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
343-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
349+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
344350

345351
bool find_slot(const llama_ubatch & batch) override;
346352

@@ -365,13 +371,10 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
365371
llama_kv_cache_unified * get_kv_swa () const;
366372

367373
private:
368-
// pending cell updates that are not yet committed
369-
struct {
370-
std::map<llama_seq_id, llama_pos> pos_max;
371-
} pending;
372-
373374
const llama_hparams & hparams;
374375

376+
std::map<llama_seq_id, llama_pos> pos_max_per_seq;
377+
375378
std::unique_ptr<llama_kv_cache_unified> kv_base;
376379
std::unique_ptr<llama_kv_cache_unified> kv_swa;
377380
};
@@ -439,7 +442,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
439442
void set_full() override;
440443

441444
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
442-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
445+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
443446

444447
bool find_slot(const llama_ubatch & batch) override;
445448

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13042,7 +13042,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1304213042
cparams.offload_kqv,
1304313043
cparams.n_ctx,
1304413044
cparams.n_seq_max,
13045-
cparams.n_batch,
13045+
cparams.n_ubatch,
1304613046
padding);
1304713047
} else {
1304813048
res = new llama_kv_cache_unified(

0 commit comments

Comments (0)