Skip to content

Commit cf33051

Browse files
committed
kv-cache : rework SWA logic to support n_ubatch + recovery
ggml-ci
1 parent cbacd62 commit cf33051

File tree

3 files changed

+105
-63
lines changed

3 files changed

+105
-63
lines changed

src/llama-kv-cache.cpp

Lines changed: 91 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -333,31 +333,41 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
333333
}
334334

335335
void llama_kv_cache_unified::restore() {
336-
if (pending.ubatches.empty()) {
337-
return;
338-
}
336+
switch (swa_type) {
337+
case LLAMA_SWA_TYPE_NONE:
338+
{
339+
uint32_t new_head = size;
339340

340-
uint32_t new_head = size;
341+
for (const auto & ubatch : pending.ubatches) {
342+
for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
343+
for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
344+
const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
341345

342-
for (const auto & ubatch : pending.ubatches) {
343-
for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
344-
for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
345-
const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
346+
cells[ubatch.head + i].seq_id.erase(seq_id);
347+
if (cells[ubatch.head + i].seq_id.empty()) {
348+
used--;
346349

347-
cells[ubatch.head + i].seq_id.erase(seq_id);
348-
if (cells[ubatch.head + i].seq_id.empty()) {
349-
used--;
350+
new_head = std::min(new_head, ubatch.head + i);
351+
}
350352

351-
new_head = std::min(new_head, ubatch.head + i);
353+
cells[ubatch.head + i].pos = -1;
354+
}
355+
}
352356
}
353357

354-
cells[ubatch.head + i].pos = -1;
355-
}
356-
}
357-
}
358+
if (new_head != size && new_head < head) {
359+
head = new_head;
360+
}
358361

359-
if (new_head != size && new_head < head) {
360-
head = new_head;
362+
} break;
363+
case LLAMA_SWA_TYPE_STANDARD:
364+
case LLAMA_SWA_TYPE_CHUNKED:
365+
{
366+
if (!pending.cells_org.empty()) {
367+
cells = std::move(pending.cells_org);
368+
used = pending.used_org;
369+
}
370+
} break;
361371
}
362372

363373
pending.clear();
@@ -460,16 +470,23 @@ void llama_kv_cache_unified::set_full() {
460470
head = 0;
461471
}
462472

463-
llama_sbatch llama_kv_cache_unified::sbatch_init(
464-
const llama_batch & batch,
465-
bool logits_all) {
473+
llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
474+
switch (swa_type) {
475+
case LLAMA_SWA_TYPE_NONE:
476+
{
477+
} break;
478+
case LLAMA_SWA_TYPE_STANDARD:
479+
case LLAMA_SWA_TYPE_CHUNKED:
480+
{
481+
pending.cells_org = cells;
482+
pending.used_org = used;
483+
} break;
484+
}
485+
466486
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
467487
}
468488

469-
llama_ubatch llama_kv_cache_unified::ubatch_next(
470-
llama_sbatch & sbatch,
471-
uint32_t n_ubatch,
472-
bool embd_pooled) const {
489+
llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
473490
GGML_UNUSED(embd_pooled);
474491
return sbatch.split_simple(n_ubatch);
475492
}
@@ -642,6 +659,33 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
642659
return ggml_cpy(ctx, v_cur, v_view);
643660
}
644661

662+
void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
663+
GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE);
664+
665+
for (uint32_t i = 0; i < size; ++i) {
666+
const llama_pos p0 = cells[i].pos;
667+
668+
if (is_masked_swa(p0, p1)) {
669+
if (seq_id < 0) {
670+
cells[i].seq_id.clear();
671+
} else if (cells[i].has_seq_id(seq_id)) {
672+
cells[i].seq_id.erase(seq_id);
673+
} else {
674+
continue;
675+
}
676+
677+
if (cells[i].is_empty()) {
678+
// keep count of the number of used cells
679+
if (cells[i].pos >= 0) {
680+
used--;
681+
}
682+
683+
cells[i].pos = -1;
684+
}
685+
}
686+
}
687+
}
688+
645689
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
646690
const int64_t n_tokens = ubatch->n_tokens;
647691
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -1589,13 +1633,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15891633
bool offload,
15901634
uint32_t kv_size,
15911635
uint32_t n_seq_max,
1592-
uint32_t n_batch,
1636+
uint32_t n_ubatch,
15931637
uint32_t padding) : hparams(model.hparams) {
15941638
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
15951639
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
15961640

15971641
const uint32_t kv_size_base = kv_size;
1598-
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
1642+
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, padding));
15991643

16001644
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
16011645

@@ -1658,21 +1702,6 @@ void llama_kv_cache_unified_iswa::restore() {
16581702
void llama_kv_cache_unified_iswa::commit() {
16591703
kv_base->commit();
16601704
kv_swa ->commit();
1661-
1662-
if (pending.pos_max.empty()) {
1663-
return;
1664-
}
1665-
1666-
// slide the attention window, forgetting/pruning old tokens that are outside the window
1667-
for (const auto & [seq_id, pos_max] : pending.pos_max) {
1668-
if (pos_max <= (llama_pos) hparams.n_swa) {
1669-
continue;
1670-
}
1671-
1672-
kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
1673-
}
1674-
1675-
pending.pos_max.clear();
16761705
}
16771706

16781707
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
@@ -1695,21 +1724,30 @@ void llama_kv_cache_unified_iswa::set_full() {
16951724
}
16961725

16971726
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1698-
for (int i = 0; i < batch.n_tokens; ++i) {
1699-
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
1700-
const llama_seq_id seq_id = batch.seq_id[i][s];
1701-
const llama_pos pos = batch.pos[i];
1727+
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
1728+
}
1729+
1730+
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
1731+
GGML_UNUSED(embd_pooled);
1732+
auto res = sbatch.split_simple(n_ubatch);
17021733

1703-
pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
1734+
pos_max_per_seq.clear();
1735+
1736+
for (uint32_t i = 0; i < res.n_tokens; ++i) {
1737+
for (int s = 0; s < res.n_seq_id[i]; ++s) {
1738+
const llama_seq_id seq_id = res.seq_id[i][s];
1739+
const llama_pos pos = res.pos[i];
1740+
1741+
pos_max_per_seq[seq_id] = std::max(pos_max_per_seq[seq_id], pos);
17041742
}
17051743
}
17061744

1707-
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
1708-
}
1745+
// slide the attention window, forgetting/pruning old tokens that are outside the window
1746+
for (const auto & [seq_id, pos_max] : pos_max_per_seq) {
1747+
kv_swa->prune_swa(seq_id, pos_max);
1748+
}
17091749

1710-
llama_ubatch llama_kv_cache_unified_iswa::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
1711-
GGML_UNUSED(embd_pooled);
1712-
return sbatch.split_simple(n_ubatch);
1750+
return res;
17131751
}
17141752

17151753
bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
@@ -2122,7 +2160,7 @@ llama_sbatch llama_kv_cache_recurrent::sbatch_init(
21222160
return llama_sbatch(batch, hparams.n_embd, false, logits_all);
21232161
}
21242162

2125-
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
2163+
llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) {
21262164
if (embd_pooled) {
21272165
// Pooled embeddings cannot be split across ubatches (yet)
21282166
return sbatch.split_seq(n_ubatch);

src/llama-kv-cache.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct llama_kv_cache : public llama_memory_i {
4343
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
4444

4545
// different KV caches require different batch splitting strategies
46-
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
46+
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) = 0;
4747

4848
// find an empty slot of size "n_tokens" in the cache
4949
virtual bool find_slot(const llama_ubatch & batch) = 0;
@@ -136,7 +136,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
136136
void set_full() override;
137137

138138
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
139-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
139+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
140140

141141
// updates the cache head
142142
// Note: On success, it's important that cache.head points
@@ -171,6 +171,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
171171
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
172172
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
173173

174+
void prune_swa(llama_seq_id seq_id, llama_pos p1);
175+
174176
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
175177
void set_input_k_shift (ggml_tensor * dst) const;
176178
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -249,10 +251,15 @@ class llama_kv_cache_unified : public llama_kv_cache {
249251
struct {
250252
void clear() {
251253
ubatches.clear();
254+
cells_org.clear();
252255
}
253256

254257
// upon batch processing failure, we revert these ubatches from the KV cells
255258
std::vector<ubatch_info> ubatches;
259+
260+
uint32_t used_org;
261+
262+
std::vector<kv_cell> cells_org;
256263
} pending;
257264

258265
// defrag
@@ -317,7 +324,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
317324
bool offload,
318325
uint32_t kv_size,
319326
uint32_t n_seq_max,
320-
uint32_t n_batch,
327+
uint32_t n_ubatch,
321328
uint32_t padding);
322329

323330
~llama_kv_cache_unified_iswa() = default;
@@ -350,7 +357,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
350357
void set_full() override;
351358

352359
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
353-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
360+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
354361

355362
bool find_slot(const llama_ubatch & batch) override;
356363

@@ -377,10 +384,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
377384
private:
378385
const llama_hparams & hparams;
379386

380-
// pending cell updates that are not yet committed
381-
struct {
382-
std::map<llama_seq_id, llama_pos> pos_max;
383-
} pending;
387+
std::map<llama_seq_id, llama_pos> pos_max_per_seq;
384388

385389
std::unique_ptr<llama_kv_cache_unified> kv_base;
386390
std::unique_ptr<llama_kv_cache_unified> kv_swa;
@@ -449,7 +453,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
449453
void set_full() override;
450454

451455
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
452-
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
456+
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) override;
453457

454458
bool find_slot(const llama_ubatch & batch) override;
455459

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13228,7 +13228,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
1322813228
cparams.offload_kqv,
1322913229
cparams.n_ctx,
1323013230
cparams.n_seq_max,
13231-
cparams.n_batch,
13231+
cparams.n_ubatch,
1323213232
padding);
1323313233
} else {
1323413234
res = new llama_kv_cache_unified(

0 commit comments

Comments (0)