Skip to content

Commit 6385d01

Browse files
committed
kv-cache : rework error recovery logic + SWA n_batch -> n_ubatch
ggml-ci
1 parent 3ad524a commit 6385d01

File tree

3 files changed

+71
-69
lines changed

3 files changed

+71
-69
lines changed

src/llama-kv-cache.cpp

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -331,43 +331,44 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
331331
}
332332

333333
void llama_kv_cache_unified::restore() {
334-
if (pending.ranges.empty()) {
334+
if (pending.ubatches.empty()) {
335335
return;
336336
}
337337

338-
// TODO: here we assume that all sequences should be removed from the cache which is not always the case
339-
// need to start keeping more detailed pending information per-sequence
340-
341338
uint32_t new_head = size;
342339

343-
for (auto & range : pending.ranges) {
344-
for (uint32_t i = range.c0; i < range.c1; ++i) {
345-
cells[i].seq_id.clear();
340+
for (const auto & ubatch : pending.ubatches) {
341+
for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
342+
for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
343+
const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
346344

347-
// keep count of the number of used cells
348-
if (cells[i].pos >= 0) {
349-
used--;
350-
}
345+
cells[ubatch.head + i].seq_id.erase(seq_id);
346+
if (cells[ubatch.head + i].seq_id.empty()) {
347+
used--;
351348

352-
cells[i].pos = -1;
353-
}
349+
new_head = std::min(new_head, ubatch.head + i);
350+
}
354351

355-
new_head = std::min(new_head, range.c0);
352+
cells[ubatch.head + i].pos = -1;
353+
}
354+
}
356355
}
357356

358357
if (new_head != size && new_head < head) {
359358
head = new_head;
360359
}
360+
361+
pending.clear();
361362
}
362363

363364
void llama_kv_cache_unified::commit() {
364-
if (pending.ranges.empty()) {
365+
if (pending.ubatches.empty()) {
365366
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
366367
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
367368
return;
368369
}
369370

370-
pending.ranges.clear();
371+
pending.clear();
371372
}
372373

373374
bool llama_kv_cache_unified::update(llama_context & lctx) {
@@ -430,6 +431,8 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
430431
do_defrag = false;
431432
}
432433

434+
pending.clear();
435+
433436
return need_reserve;
434437
}
435438

@@ -459,7 +462,7 @@ llama_sbatch llama_kv_cache_unified::sbatch_init(
459462
llama_ubatch llama_kv_cache_unified::ubatch_next(
460463
llama_sbatch & sbatch,
461464
uint32_t n_ubatch,
462-
bool embd_pooled) const {
465+
bool embd_pooled) {
463466
GGML_UNUSED(embd_pooled);
464467
return sbatch.split_simple(n_ubatch);
465468
}
@@ -519,7 +522,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
519522

520523
used += n_tokens;
521524

522-
pending.ranges.push_back({head, head + n_tokens});
525+
pending.ubatches.push_back({ head, ubatch });
523526

524527
// a heuristic, to avoid attending the full cache if it is not yet utilized
525528
// after enough generations, the benefit from this heuristic disappears
@@ -1568,13 +1571,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15681571
bool offload,
15691572
uint32_t kv_size,
15701573
uint32_t n_seq_max,
1571-
uint32_t n_batch,
1574+
uint32_t n_ubatch,
15721575
uint32_t padding) : hparams(model.hparams) {
15731576
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
15741577
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
15751578

15761579
const uint32_t kv_size_base = kv_size;
1577-
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
1580+
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, padding));
15781581

15791582
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
15801583

@@ -1629,21 +1632,6 @@ void llama_kv_cache_unified_iswa::restore() {
16291632
}
16301633

16311634
void llama_kv_cache_unified_iswa::commit() {
1632-
if (pending.pos_max.empty()) {
1633-
return;
1634-
}
1635-
1636-
// slide the window, forgetting old tokens
1637-
for (const auto & [seq_id, pos_max] : pending.pos_max) {
1638-
if (pos_max <= (llama_pos) hparams.n_swa) {
1639-
continue;
1640-
}
1641-
1642-
kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
1643-
}
1644-
1645-
pending.pos_max.clear();
1646-
16471635
kv_base->commit();
16481636
kv_swa ->commit();
16491637
}
@@ -1668,21 +1656,34 @@ void llama_kv_cache_unified_iswa::set_full() {
16681656
}
16691657

16701658
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1671-
// this will be used upon successful decode, during commit, to remove old SWA tokens
1672-
for (int i = 0; i < batch.n_tokens; ++i) {
1673-
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
1674-
const llama_seq_id seq_id = batch.seq_id[i][s];
1675-
const llama_pos pos = batch.pos[i];
1659+
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
1660+
}
1661+
1662+
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
1663+
GGML_UNUSED(embd_pooled);
1664+
auto res = sbatch.split_simple(n_ubatch);
1665+
1666+
for (uint32_t i = 0; i < res.n_tokens; ++i) {
1667+
for (int s = 0; s < res.n_seq_id[i]; ++s) {
1668+
const llama_seq_id seq_id = res.seq_id[i][s];
1669+
const llama_pos pos = res.pos[i];
16761670

1677-
pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
1671+
pos_max_per_seq[seq_id] = std::max(pos_max_per_seq[seq_id], pos);
16781672
}
16791673
}
16801674

1681-
return kv_base->sbatch_init(batch, logits_all);
1682-
}
1675+
// slide the window, forgetting old tokens
1676+
for (const auto & [seq_id, pos_max] : pos_max_per_seq) {
1677+
if (pos_max <= (llama_pos) hparams.n_swa) {
1678+
continue;
1679+
}
16831680

1684-
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1685-
return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
1681+
kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
1682+
}
1683+
1684+
pos_max_per_seq.clear();
1685+
1686+
return res;
16861687
}
16871688

16881689
bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
@@ -2094,7 +2095,7 @@ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
20942095
return llama_sbatch(batch, hparams.n_embd, false, logits_all);
20952096
}
20962097

2097-
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
2098+
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
20982099
if (embd_pooled) {
20992100
// Pooled embeddings cannot be split across ubatches (yet)
21002101
return sbatch.split_seq(n_ubatch);

src/llama-kv-cache.h

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "llama.h"
44
#include "llama-io.h"
5+
#include "llama-batch.h"
56
#include "llama-graph.h"
67
#include "llama-memory.h"
78

@@ -13,8 +14,6 @@
1314

1415
struct llama_cparams;
1516
struct llama_hparams;
16-
struct llama_ubatch;
17-
struct llama_sbatch;
1817
struct llama_model;
1918
struct llama_context;
2019

@@ -28,7 +27,7 @@ struct llama_kv_cache : public llama_memory_i {
2827
virtual void commit() = 0;
2928

3029
// process any pending defrag/shift/etc. operations
31-
// optionally call once before processing a new batch
30+
// call once before processing a new batch
3231
virtual bool update(llama_context & lctx) = 0;
3332

3433
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
@@ -44,7 +43,7 @@ struct llama_kv_cache : public llama_memory_i {
4443
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
4544

4645
// different KV caches require different batch splitting strategies
47-
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
46+
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) = 0;
4847

4948
// find an empty slot of size "n_tokens" in the cache
5049
virtual bool find_slot(const llama_ubatch & batch) = 0;
@@ -135,7 +134,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
135134
void set_full() override;
136135

137136
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
138-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
137+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
139138

140139
// updates the cache head
141140
// Note: On success, it's important that cache.head points
@@ -178,16 +177,11 @@ class llama_kv_cache_unified : public llama_kv_cache {
178177
const llama_model & model;
179178
const llama_hparams & hparams;
180179

181-
// commit/restore cache
182-
struct slot_range {
183-
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
184-
uint32_t c1 = 0;
185-
};
186-
187180
struct kv_cell {
188181
llama_pos pos = -1;
189182
llama_pos delta = 0;
190183

184+
// TODO: replace with bitset uint64_t
191185
std::set<llama_seq_id> seq_id;
192186

193187
bool has_seq_id(const llama_seq_id & id) const {
@@ -241,10 +235,20 @@ class llama_kv_cache_unified : public llama_kv_cache {
241235
// model layer id -> KV cache layer id
242236
std::map<int32_t, int32_t> map_layer_ids;
243237

244-
// pending cell updates that are not yet committed
245-
// TODO: improve by keeping information per-sequence
238+
struct ubatch_info {
239+
uint32_t head;
240+
241+
llama_ubatch data;
242+
};
243+
244+
// pending cell updates that are not yet committed - cleared upon update()
246245
struct {
247-
std::vector<slot_range> ranges;
246+
void clear() {
247+
ubatches.clear();
248+
}
249+
250+
// upon batch processing failure, we revert these ubatches from the KV cells
251+
std::vector<ubatch_info> ubatches;
248252
} pending;
249253

250254
// defrag
@@ -307,7 +311,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
307311
bool offload,
308312
uint32_t kv_size,
309313
uint32_t n_seq_max,
310-
uint32_t n_batch,
314+
uint32_t n_ubatch,
311315
uint32_t padding);
312316

313317
~llama_kv_cache_unified_iswa() = default;
@@ -340,7 +344,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
340344
void set_full() override;
341345

342346
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
343-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
347+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
344348

345349
bool find_slot(const llama_ubatch & batch) override;
346350

@@ -365,13 +369,10 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
365369
llama_kv_cache_unified * get_kv_swa () const;
366370

367371
private:
368-
// pending cell updates that are not yet committed
369-
struct {
370-
std::map<llama_seq_id, llama_pos> pos_max;
371-
} pending;
372-
373372
const llama_hparams & hparams;
374373

374+
std::map<llama_seq_id, llama_pos> pos_max_per_seq;
375+
375376
std::unique_ptr<llama_kv_cache_unified> kv_base;
376377
std::unique_ptr<llama_kv_cache_unified> kv_swa;
377378
};
@@ -439,7 +440,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
439440
void set_full() override;
440441

441442
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
442-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
443+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
443444

444445
bool find_slot(const llama_ubatch & batch) override;
445446

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13042,7 +13042,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1304213042
cparams.offload_kqv,
1304313043
cparams.n_ctx,
1304413044
cparams.n_seq_max,
13045-
cparams.n_batch,
13045+
cparams.n_ubatch,
1304613046
padding);
1304713047
} else {
1304813048
res = new llama_kv_cache_unified(

0 commit comments

Comments
 (0)