
Commit 27057f2

kv-cache : rework recovery logic + restrict SWA batch params
ggml-ci
1 parent b02a5a8 commit 27057f2

File tree

4 files changed: +122 −70 lines changed


src/llama-context.cpp

Lines changed: 13 additions & 0 deletions
@@ -93,6 +93,14 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
+    // TODO: allowing this requires major refactor of the KV cache logic
+    // ref: https://github.com/ggml-org/llama.cpp/pull/13194
+    if (hparams.n_swa > 0 && cparams.n_ubatch < cparams.n_batch) {
+        LLAMA_LOG_WARN("%s: SWA models currently do not support n_ubatch < n_batch - increasing n_ubatch to %d\n", __func__, cparams.n_batch);
+        cparams.n_ubatch = cparams.n_batch;
+    }
+
     cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
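The practical effect of the hunk above: for models with a sliding attention window (hparams.n_swa > 0), requesting n_ubatch < n_batch is overridden at context creation. A minimal usage sketch, assuming the public llama.h API; the model path and sizes are placeholders, not taken from this commit:

#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path, error handling omitted

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx    = 8192;
    cparams.n_batch  = 2048;
    cparams.n_ubatch = 512;  // for an SWA model this is now raised to n_batch (2048), with a warning

    llama_context * ctx = llama_init_from_model(model, cparams);

    // ... tokenize / decode as usual ...

    llama_free(ctx);
    llama_model_free(model);
    return 0;
}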
@@ -916,6 +924,9 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -2;
     };
 
+    // =============================================================================================================
+    // TODO: refactor and simplify this
+
     // handle any pending defrags/shifts
     kv_self_update();
 
@@ -967,6 +978,8 @@ int llama_context::decode(llama_batch & inp_batch) {
             ubatches.clear();
         }
 
+    // =============================================================================================================
+
     // we now have prepared the ubatches for this llama_decode and are ready to start processing
 
     int64_t n_outputs_prev = 0;

src/llama-kv-cache.cpp

Lines changed: 90 additions & 50 deletions
@@ -333,44 +333,31 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ubatches.empty()) {
-        return;
-    }
-
-    uint32_t new_head = size;
-
-    for (const auto & ubatch : pending.ubatches) {
-        for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
-            for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
-
-                cells[ubatch.head + i].seq_id.erase(seq_id);
-                if (cells[ubatch.head + i].seq_id.empty()) {
-                    used--;
-
-                    new_head = std::min(new_head, ubatch.head + i);
-                }
-
-                cells[ubatch.head + i].pos = -1;
-            }
-        }
-    }
-
-    if (new_head != size && new_head < head) {
-        head = new_head;
-    }
+    for (const auto & [id, cell] : recovery.cells) {
+        // TODO: move to new `struct kv_cells`
+        const bool is_empty0 = cells[id].is_empty();
+        const bool is_empty1 = cell.is_empty();
+
+        if (!is_empty0 && is_empty1) {
+            used--;
+        } else if (is_empty0 && !is_empty1) {
+            used++;
+        }
+
+        cells[id] = cell;
+    }
 
-    pending.clear();
+    recovery.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (pending.ubatches.empty()) {
-        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
-            __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+    if (recovery.cells.empty()) {
+        LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
+            __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
         return;
     }
 
-    pending.clear();
+    recovery.clear();
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
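The reworked restore() relies on find_slot() having snapshotted the original contents of every cell it is about to overwrite; restoring is then just writing the snapshots back and fixing up the used counter, and commit() simply drops the snapshots. A self-contained sketch of the same save-on-first-touch / rollback pattern, using a simplified stand-in for kv_cell rather than the real struct:

#include <cstdint>
#include <set>
#include <unordered_map>
#include <vector>

// Stand-in for the real kv_cell: a position and the set of sequences using it.
struct kv_cell {
    int64_t pos = -1;
    std::set<int32_t> seq_id;

    bool is_empty() const { return seq_id.empty(); }
};

struct cache {
    std::vector<kv_cell> cells;
    uint32_t used = 0;

    // recovery information: original state of every cell touched since the last commit
    std::unordered_map<uint32_t, kv_cell> recovery;

    void write(uint32_t id, int64_t pos, int32_t seq) {
        // remember the original state only the first time a cell is touched
        if (recovery.find(id) == recovery.end()) {
            recovery[id] = cells[id];
        }
        if (cells[id].is_empty()) {
            used++;
        }
        cells[id].pos = pos;
        cells[id].seq_id.insert(seq);
    }

    void restore() {
        for (const auto & [id, cell] : recovery) {
            const bool is_empty0 = cells[id].is_empty();
            const bool is_empty1 = cell.is_empty();

            if (!is_empty0 && is_empty1) {
                used--;
            } else if (is_empty0 && !is_empty1) {
                used++;
            }

            cells[id] = cell; // put the original content back
        }
        recovery.clear();
    }

    void commit() { recovery.clear(); }
};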
@@ -460,16 +447,11 @@ void llama_kv_cache_unified::set_full() {
     head = 0;
 }
 
-llama_sbatch llama_kv_cache_unified::sbatch_init(
-        const llama_batch & batch,
-        bool logits_all) {
+llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
     return llama_sbatch(batch, hparams.n_embd, true, logits_all);
 }
 
-llama_ubatch llama_kv_cache_unified::ubatch_next(
-        llama_sbatch & sbatch,
-        uint32_t n_ubatch,
-        bool embd_pooled) const {
+llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
     GGML_UNUSED(embd_pooled);
     return sbatch.split_simple(n_ubatch);
 }
@@ -490,6 +472,29 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
         return false;
     }
 
+//#define FIND_SLOT_DEBUG 1
+#if FIND_SLOT_DEBUG
+    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+
+    // for debugging
+    {
+        std::string ss;
+        if (n_swa > 0) {
+            for (uint32_t i = 0; i < size; ++i) {
+                if (cells[i].pos == -1) {
+                    ss += '.';
+                } else {
+                    ss += std::to_string(*cells[i].seq_id.begin());
+                }
+                if (i%256 == 255) {
+                    ss += '\n';
+                }
+            }
+        }
+        LLAMA_LOG_WARN("\n%s\n", ss.c_str());
+    }
+#endif
+
     uint32_t n_tested = 0;
 
     while (true) {
@@ -520,6 +525,11 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     }
 
     for (uint32_t i = 0; i < n_tokens; ++i) {
+        // remember the original state
+        if (recovery.cells.find(head + i) == recovery.cells.end()) {
+            recovery.cells[head + i] = cells[head + i];
+        }
+
         cells[head + i].pos = ubatch.pos[i];
 
         for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
@@ -529,14 +539,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
     used += n_tokens;
 
-    pending.ubatches.push_back({ head, ubatch });
-
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
     n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
 
-    //printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
+#ifdef FIND_SLOT_DEBUG
+    LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+#endif
 
     return true;
 }
@@ -642,6 +652,34 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
+    // no pruning is needed when the cache does not use SWA
+    GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
+
+    for (uint32_t i = 0; i < size; ++i) {
+        const llama_pos p0 = cells[i].pos;
+
+        if (is_masked_swa(p0, p1)) {
+            if (seq_id < 0) {
+                cells[i].seq_id.clear();
+            } else if (cells[i].has_seq_id(seq_id)) {
+                cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+
+            if (cells[i].is_empty()) {
+                // keep count of the number of used cells
+                if (cells[i].pos >= 0) {
+                    used--;
+                }
+
+                cells[i].pos = -1;
+            }
+        }
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const int64_t n_tokens = ubatch->n_tokens;
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
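prune_swa() replaces the old seq_rm()-based forgetting: rather than erasing everything before pos_max - n_swa + 1, it removes exactly the cells that the SWA mask would hide for the given sequence once decoding has reached position p1. A minimal sketch of the check for the standard sliding-window case; the "distance >= n_swa" rule here is an assumption about the usual variant, since the real is_masked_swa() also handles the other swa_type values:

#include <cstdint>

// Standard sliding window (assumed rule): a key at position p0 can no longer be
// attended by a query at position p1 once their distance reaches n_swa.
static bool is_masked_swa_standard(int64_t n_swa, int64_t p0, int64_t p1) {
    if (p0 < 0) {
        return true; // empty cell, never attended
    }
    return p1 - p0 >= n_swa;
}

Cells for which the check returns true are then dropped from the sequence by prune_swa(), with used decremented whenever a cell becomes empty.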
@@ -1181,6 +1219,10 @@ uint32_t llama_kv_cache_unified::cell_max() const {
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    if (p0 < 0) {
+        return true;
+    }
+
     switch (swa_type) {
         case LLAMA_SWA_TYPE_NONE:
             {
@@ -1589,13 +1631,13 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
15891631
bool offload,
15901632
uint32_t kv_size,
15911633
uint32_t n_seq_max,
1592-
uint32_t n_batch,
1634+
uint32_t n_ubatch,
15931635
uint32_t padding) : hparams(model.hparams) {
15941636
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
15951637
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
15961638

15971639
const uint32_t kv_size_base = kv_size;
1598-
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
1640+
const uint32_t kv_size_swa = std::min(kv_size, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, padding));
15991641

16001642
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, kv_size_base);
16011643
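Sizing the SWA cache by n_ubatch instead of n_batch works because, together with the llama-context.cpp restriction above, n_ubatch bounds how many new tokens a single find_slot() call can place into the SWA cache on top of the n_swa*n_seq_max tokens that must be retained. A rough worked example with assumed values, not taken from any particular model:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Round x up to a multiple of n - the same result GGML_PAD gives for power-of-two n.
static uint32_t pad_to(uint32_t x, uint32_t n) { return ((x + n - 1) / n) * n; }

int main() {
    // Assumed values for illustration only:
    const uint32_t n_swa = 4096, n_seq_max = 1, n_ubatch = 512, padding = 256;
    const uint32_t kv_size = 32768; // the full (non-SWA) cache size, i.e. n_ctx

    const uint32_t kv_size_swa = std::min(kv_size, pad_to(n_swa*n_seq_max + n_ubatch, padding));

    printf("kv_size_swa = %u cells\n", kv_size_swa); // 4608 cells, versus 32768 for the non-SWA cache
    return 0;
}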

@@ -1659,20 +1701,12 @@ void llama_kv_cache_unified_iswa::commit() {
     kv_base->commit();
     kv_swa ->commit();
 
-    if (pending.pos_max.empty()) {
-        return;
-    }
-
     // slide the attention window, forgetting/pruning old tokens that are outside the window
     for (const auto & [seq_id, pos_max] : pending.pos_max) {
-        if (pos_max <= (llama_pos) hparams.n_swa) {
-            continue;
-        }
-
-        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
+        kv_swa->prune_swa(seq_id, pos_max);
     }
 
-    pending.pos_max.clear();
+    pending.clear();
 }
 
 bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
@@ -1695,12 +1729,18 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
+    pending.pos_max.clear();
+
     for (int i = 0; i < batch.n_tokens; ++i) {
         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
             const llama_seq_id seq_id = batch.seq_id[i][s];
            const llama_pos    pos    = batch.pos[i];
 
-            pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+            if (pending.pos_max.find(seq_id) == pending.pos_max.end()) {
+                pending.pos_max[seq_id] = pos;
+            } else {
+                pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+            }
         }
     }
 
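The two hunks above also change when the per-sequence position bookkeeping is reset: pending.pos_max is now cleared at the start of every sbatch_init(), so stale entries from a failed batch cannot leak into the next commit(), which prunes against the highest position seen in the current batch. A compact sketch of that bookkeeping, with simplified types standing in for the real ones:

#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

using seq_id_t = int32_t;
using pos_t    = int32_t;

struct pending_t {
    std::unordered_map<seq_id_t, pos_t> pos_max;
    void clear() { pos_max.clear(); }
};

// Mirrors sbatch_init(): record the highest position per sequence for one batch.
static void track_batch(pending_t & pending, const std::vector<std::pair<seq_id_t, pos_t>> & tokens) {
    pending.clear(); // start fresh for every batch

    for (const auto & [seq, pos] : tokens) {
        auto it = pending.pos_max.find(seq);
        if (it == pending.pos_max.end()) {
            pending.pos_max[seq] = pos;
        } else {
            it->second = std::max(it->second, pos);
        }
    }
}

// At commit time, each entry is handed to the SWA cache:
//     for (const auto & [seq_id, pos_max] : pending.pos_max) kv_swa->prune_swa(seq_id, pos_max);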

src/llama-kv-cache.h

Lines changed: 18 additions & 19 deletions
@@ -2,18 +2,19 @@
 
 #include "llama.h"
 #include "llama-io.h"
-#include "llama-batch.h"
 #include "llama-graph.h"
 #include "llama-memory.h"
 
 #include "ggml-cpp.h"
 
-#include <map>
 #include <set>
+#include <unordered_map>
 #include <vector>
 
 struct llama_cparams;
 struct llama_hparams;
+struct llama_ubatch;
+struct llama_sbatch;
 struct llama_model;
 struct llama_context;
 
@@ -171,6 +172,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
 
+    void prune_swa(llama_seq_id seq_id, llama_pos p1);
+
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -214,7 +217,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
 
     uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
     uint32_t size = 0; // total number of cells, shared across all sequences
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+    uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automatically)
 
     // computed before each graph build
     uint32_t n = 0;
@@ -233,27 +236,20 @@ class llama_kv_cache_unified : public llama_kv_cache {
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    std::vector<kv_cell> cells;
+    std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
     std::vector<kv_layer> layers;
 
     // model layer id -> KV cache layer id
-    std::map<int32_t, int32_t> map_layer_ids;
-
-    struct ubatch_info {
-        uint32_t head;
+    std::unordered_map<int32_t, int32_t> map_layer_ids;
 
-        llama_ubatch data;
-    };
-
-    // pending cell updates that are not yet committed
+    // recovery information used to restore the KV cells to their original state in case of a failure
     struct {
         void clear() {
-            ubatches.clear();
+            cells.clear();
         }
 
-        // upon batch processing failure, we revert these ubatches from the KV cells
-        std::vector<ubatch_info> ubatches;
-    } pending;
+        std::unordered_map<uint32_t, kv_cell> cells;
+    } recovery;
 
     // defrag
     struct {
@@ -317,7 +313,7 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
                 bool     offload,
                 uint32_t kv_size,
                 uint32_t n_seq_max,
-                uint32_t n_batch,
+                uint32_t n_ubatch,
                 uint32_t padding);
 
     ~llama_kv_cache_unified_iswa() = default;
@@ -377,9 +373,12 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {
 private:
     const llama_hparams & hparams;
 
-    // pending cell updates that are not yet committed
     struct {
-        std::map<llama_seq_id, llama_pos> pos_max;
+        void clear() {
+            pos_max.clear();
+        }
+
+        std::unordered_map<llama_seq_id, llama_pos> pos_max;
     } pending;
 
     std::unique_ptr<llama_kv_cache_unified> kv_base;

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
@@ -13228,7 +13228,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         cparams.offload_kqv,
                         cparams.n_ctx,
                         cparams.n_seq_max,
-                        cparams.n_batch,
+                        cparams.n_ubatch,
                         padding);
             } else {
                 res = new llama_kv_cache_unified(
