Skip to content

Commit 206714d

Browse files
committed
kv-cache : rework error recovery logic + SWA n_batch -> n_ubatch
ggml-ci
1 parent 3ad524a commit 206714d

File tree

3 files changed

+100
-69
lines changed

3 files changed

+100
-69
lines changed

src/llama-kv-cache.cpp

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
161161

162162
for (uint32_t i = 0; i < size; ++i) {
163163
if (cells[i].pos >= p0 && cells[i].pos < p1) {
164+
pending.seq_rms.push_back({ seq_id, cells[i].pos, i });
165+
164166
if (seq_id < 0) {
165167
cells[i].seq_id.clear();
166168
} else if (cells[i].has_seq_id(seq_id)) {
@@ -331,43 +333,58 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
331333
}
332334

333335
void llama_kv_cache_unified::restore() {
334-
if (pending.ranges.empty()) {
336+
if (pending.ubatches.empty()) {
335337
return;
336338
}
337339

338-
// TODO: here we assume that all sequences should be removed from the cache which is not always the case
339-
// need to start keeping more detailed pending information per-sequence
340-
341340
uint32_t new_head = size;
342341

343-
for (auto & range : pending.ranges) {
344-
for (uint32_t i = range.c0; i < range.c1; ++i) {
345-
cells[i].seq_id.clear();
342+
for (const auto & ubatch : pending.ubatches) {
343+
for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
344+
for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
345+
const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
346346

347-
// keep count of the number of used cells
348-
if (cells[i].pos >= 0) {
349-
used--;
350-
}
347+
cells[ubatch.head + i].seq_id.erase(seq_id);
348+
if (cells[ubatch.head + i].seq_id.empty()) {
349+
used--;
351350

352-
cells[i].pos = -1;
353-
}
351+
new_head = std::min(new_head, ubatch.head + i);
352+
}
354353

355-
new_head = std::min(new_head, range.c0);
354+
cells[ubatch.head + i].pos = -1;
355+
}
356+
}
356357
}
357358

358359
if (new_head != size && new_head < head) {
359360
head = new_head;
360361
}
362+
363+
for (const auto & seq_rm : pending.seq_rms) {
364+
GGML_ASSERT(seq_rm.seq_id >= 0 && "seq_rm.seq_id < 0 during restore - should not happen");
365+
366+
if (cells[seq_rm.c].seq_id.empty()) {
367+
GGML_ASSERT(cells[seq_rm.c].pos == -1 && "cells[seq_rm.c].pos != -1 during restore - should not happen");
368+
used++;
369+
} else {
370+
GGML_ASSERT(cells[seq_rm.c].pos == seq_rm.p && "cells[seq_rm.c].pos != seq_rm.p during restore - should not happen");
371+
}
372+
373+
cells[seq_rm.c].seq_id.insert(seq_rm.seq_id);
374+
cells[seq_rm.c].pos = seq_rm.p;
375+
}
376+
377+
pending.clear();
361378
}
362379

363380
void llama_kv_cache_unified::commit() {
364-
if (pending.ranges.empty()) {
381+
if (pending.ubatches.empty()) {
365382
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
366383
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
367384
return;
368385
}
369386

370-
pending.ranges.clear();
387+
pending.clear();
371388
}
372389

373390
bool llama_kv_cache_unified::update(llama_context & lctx) {
@@ -430,6 +447,8 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
430447
do_defrag = false;
431448
}
432449

450+
pending.clear();
451+
433452
return need_reserve;
434453
}
435454

@@ -459,7 +478,7 @@ llama_sbatch llama_kv_cache_unified::sbatch_init(
459478
llama_ubatch llama_kv_cache_unified::ubatch_next(
460479
llama_sbatch & sbatch,
461480
uint32_t n_ubatch,
462-
bool embd_pooled) const {
481+
bool embd_pooled) {
463482
GGML_UNUSED(embd_pooled);
464483
return sbatch.split_simple(n_ubatch);
465484
}
@@ -519,7 +538,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
519538

520539
used += n_tokens;
521540

522-
pending.ranges.push_back({head, head + n_tokens});
541+
pending.ubatches.push_back({ head, ubatch });
523542

524543
// a heuristic, to avoid attending the full cache if it is not yet utilized
525544
// after enough generations, the benefit from this heuristic disappears
@@ -1568,13 +1587,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15681587
bool offload,
15691588
uint32_t kv_size,
15701589
uint32_t n_seq_max,
1571-
uint32_t n_batch,
1590+
uint32_t n_ubatch,
15721591
uint32_t padding) : hparams(model.hparams) {
15731592
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
15741593
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
15751594

15761595
const uint32_t kv_size_base = kv_size;
1577-
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
1596+
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, padding));
15781597

15791598
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
15801599

@@ -1629,21 +1648,6 @@ void llama_kv_cache_unified_iswa::restore() {
16291648
}
16301649

16311650
void llama_kv_cache_unified_iswa::commit() {
1632-
if (pending.pos_max.empty()) {
1633-
return;
1634-
}
1635-
1636-
// slide the window, forgetting old tokens
1637-
for (const auto & [seq_id, pos_max] : pending.pos_max) {
1638-
if (pos_max <= (llama_pos) hparams.n_swa) {
1639-
continue;
1640-
}
1641-
1642-
kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
1643-
}
1644-
1645-
pending.pos_max.clear();
1646-
16471651
kv_base->commit();
16481652
kv_swa ->commit();
16491653
}
@@ -1668,21 +1672,34 @@ void llama_kv_cache_unified_iswa::set_full() {
16681672
}
16691673

16701674
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1671-
// this will be used upon successful decode, during commit, to remove old SWA tokens
1672-
for (int i = 0; i < batch.n_tokens; ++i) {
1673-
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
1674-
const llama_seq_id seq_id = batch.seq_id[i][s];
1675-
const llama_pos pos = batch.pos[i];
1675+
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
1676+
}
16761677

1677-
pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
1678+
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
1679+
GGML_UNUSED(embd_pooled);
1680+
auto res = sbatch.split_simple(n_ubatch);
1681+
1682+
for (uint32_t i = 0; i < res.n_tokens; ++i) {
1683+
for (int s = 0; s < res.n_seq_id[i]; ++s) {
1684+
const llama_seq_id seq_id = res.seq_id[i][s];
1685+
const llama_pos pos = res.pos[i];
1686+
1687+
pos_max_per_seq[seq_id] = std::max(pos_max_per_seq[seq_id], pos);
16781688
}
16791689
}
16801690

1681-
return kv_base->sbatch_init(batch, logits_all);
1682-
}
1691+
// slide the window, forgetting old tokens
1692+
for (const auto & [seq_id, pos_max] : pos_max_per_seq) {
1693+
if (pos_max <= (llama_pos) hparams.n_swa) {
1694+
continue;
1695+
}
1696+
1697+
kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
1698+
}
1699+
1700+
pos_max_per_seq.clear();
16831701

1684-
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1685-
return kv_base->ubatch_next(sbatch, n_ubatch, embd_pooled);
1702+
return res;
16861703
}
16871704

16881705
bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
@@ -2094,7 +2111,7 @@ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
20942111
return llama_sbatch(batch, hparams.n_embd, false, logits_all);
20952112
}
20962113

2097-
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
2114+
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
20982115
if (embd_pooled) {
20992116
// Pooled embeddings cannot be split across ubatches (yet)
21002117
return sbatch.split_seq(n_ubatch);

src/llama-kv-cache.h

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "llama.h"
44
#include "llama-io.h"
5+
#include "llama-batch.h"
56
#include "llama-graph.h"
67
#include "llama-memory.h"
78

@@ -13,8 +14,6 @@
1314

1415
struct llama_cparams;
1516
struct llama_hparams;
16-
struct llama_ubatch;
17-
struct llama_sbatch;
1817
struct llama_model;
1918
struct llama_context;
2019

@@ -28,7 +27,7 @@ struct llama_kv_cache : public llama_memory_i {
2827
virtual void commit() = 0;
2928

3029
// process any pending defrag/shift/etc. operations
31-
// optionally call once before processing a new batch
30+
// call once before processing a new batch
3231
virtual bool update(llama_context & lctx) = 0;
3332

3433
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
@@ -44,7 +43,7 @@ struct llama_kv_cache : public llama_memory_i {
4443
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
4544

4645
// different KV caches require different batch splitting strategies
47-
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
46+
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) = 0;
4847

4948
// find an empty slot of size "n_tokens" in the cache
5049
virtual bool find_slot(const llama_ubatch & batch) = 0;
@@ -135,7 +134,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
135134
void set_full() override;
136135

137136
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
138-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
137+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
139138

140139
// updates the cache head
141140
// Note: On success, it's important that cache.head points
@@ -178,16 +177,11 @@ class llama_kv_cache_unified : public llama_kv_cache {
178177
const llama_model & model;
179178
const llama_hparams & hparams;
180179

181-
// commit/restore cache
182-
struct slot_range {
183-
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
184-
uint32_t c1 = 0;
185-
};
186-
187180
struct kv_cell {
188181
llama_pos pos = -1;
189182
llama_pos delta = 0;
190183

184+
// TODO: replace with bitset uint64_t
191185
std::set<llama_seq_id> seq_id;
192186

193187
bool has_seq_id(const llama_seq_id & id) const {
@@ -241,10 +235,33 @@ class llama_kv_cache_unified : public llama_kv_cache {
241235
// model layer id -> KV cache layer id
242236
std::map<int32_t, int32_t> map_layer_ids;
243237

244-
// pending cell updates that are not yet committed
245-
// TODO: improve by keeping information per-sequence
238+
struct ubatch_info {
239+
uint32_t head;
240+
241+
llama_ubatch data;
242+
};
243+
244+
struct seq_rm_info {
245+
llama_seq_id seq_id;
246+
247+
llama_pos p;
248+
249+
uint32_t c;
250+
};
251+
252+
// pending cell updates that are not yet committed - cleared upon update()
246253
struct {
247-
std::vector<slot_range> ranges;
254+
void clear() {
255+
ubatches.clear();
256+
seq_rms.clear();
257+
}
258+
259+
// upon batch processing failure, we revert these ubatches from the KV cells
260+
std::vector<ubatch_info> ubatches;
261+
262+
// any cell removals that occur during the current batch processing will be restored with this information
263+
// this is relevant for SWA caches that perform token pruning on each ubatch
264+
std::vector<seq_rm_info> seq_rms;
248265
} pending;
249266

250267
// defrag
@@ -307,7 +324,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
307324
bool offload,
308325
uint32_t kv_size,
309326
uint32_t n_seq_max,
310-
uint32_t n_batch,
327+
uint32_t n_ubatch,
311328
uint32_t padding);
312329

313330
~llama_kv_cache_unified_iswa() = default;
@@ -340,7 +357,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
340357
void set_full() override;
341358

342359
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
343-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
360+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
344361

345362
bool find_slot(const llama_ubatch & batch) override;
346363

@@ -365,13 +382,10 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
365382
llama_kv_cache_unified * get_kv_swa () const;
366383

367384
private:
368-
// pending cell updates that are not yet committed
369-
struct {
370-
std::map<llama_seq_id, llama_pos> pos_max;
371-
} pending;
372-
373385
const llama_hparams & hparams;
374386

387+
std::map<llama_seq_id, llama_pos> pos_max_per_seq;
388+
375389
std::unique_ptr<llama_kv_cache_unified> kv_base;
376390
std::unique_ptr<llama_kv_cache_unified> kv_swa;
377391
};
@@ -439,7 +453,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
439453
void set_full() override;
440454

441455
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
442-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
456+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
443457

444458
bool find_slot(const llama_ubatch & batch) override;
445459

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13042,7 +13042,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1304213042
cparams.offload_kqv,
1304313043
cparams.n_ctx,
1304413044
cparams.n_seq_max,
13045-
cparams.n_batch,
13045+
cparams.n_ubatch,
1304613046
padding);
1304713047
} else {
1304813048
res = new llama_kv_cache_unified(

0 commit comments

Comments
 (0)