Skip to content

Commit 50bc961

Browse files
committed
kv-cache : fix bug with n_ubatch < n_batch
ggml-ci
1 parent b02a5a8 commit 50bc961

File tree

3 files changed

+159
-67
lines changed

3 files changed

+159
-67
lines changed

src/llama-context.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ llama_context::llama_context(
9393
}
9494

9595
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
96+
9697
cparams.op_offload = params.op_offload;
9798

9899
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -916,6 +917,9 @@ int llama_context::decode(llama_batch & inp_batch) {
916917
return -2;
917918
};
918919

920+
// =============================================================================================================
921+
// TODO: refactor the llama_kv_cache interface and simplify this
922+
919923
// handle any pending defrags/shifts
920924
kv_self_update();
921925

@@ -967,11 +971,16 @@ int llama_context::decode(llama_batch & inp_batch) {
967971
ubatches.clear();
968972
}
969973

974+
// =============================================================================================================
975+
970976
// we now have prepared the ubatches for this llama_decode and are ready to start processing
971977

972978
int64_t n_outputs_prev = 0;
973979

974-
for (const auto & ubatch : ubatches) {
980+
for (int i = 0; i < (int) ubatches.size(); ++i) {
981+
const auto & ubatch = ubatches[i];
982+
kv_self->set_state(i);
983+
975984
// count the outputs in this u_batch
976985
{
977986
int32_t n_outputs_new = 0;

src/llama-kv-cache.cpp

Lines changed: 112 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -333,44 +333,33 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
333333
}
334334

335335
void llama_kv_cache_unified::restore() {
336-
if (pending.ubatches.empty()) {
337-
return;
338-
}
339-
340-
uint32_t new_head = size;
341-
342-
for (const auto & ubatch : pending.ubatches) {
343-
for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
344-
for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
345-
const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
336+
for (const auto & [id, cell] : recovery.cells) {
337+
// TODO: move to new `struct kv_cells`
338+
const bool is_empty0 = cells[id].is_empty();
339+
const bool is_empty1 = cell.is_empty();
346340

347-
cells[ubatch.head + i].seq_id.erase(seq_id);
348-
if (cells[ubatch.head + i].seq_id.empty()) {
349-
used--;
350-
351-
new_head = std::min(new_head, ubatch.head + i);
352-
}
353-
354-
cells[ubatch.head + i].pos = -1;
355-
}
341+
if (!is_empty0 && is_empty1) {
342+
used--;
343+
} else if (is_empty0 && !is_empty1) {
344+
used++;
356345
}
357-
}
358346

359-
if (new_head != size && new_head < head) {
360-
head = new_head;
347+
cells[id] = cell;
361348
}
362349

363-
pending.clear();
350+
recovery.clear();
351+
states.clear();
364352
}
365353

366354
void llama_kv_cache_unified::commit() {
367-
if (pending.ubatches.empty()) {
368-
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
369-
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
355+
if (recovery.cells.empty()) {
356+
LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
357+
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
370358
return;
371359
}
372360

373-
pending.clear();
361+
recovery.clear();
362+
states.clear();
374363
}
375364

376365
bool llama_kv_cache_unified::update(llama_context & lctx) {
@@ -460,16 +449,11 @@ void llama_kv_cache_unified::set_full() {
460449
head = 0;
461450
}
462451

463-
llama_sbatch llama_kv_cache_unified::sbatch_init(
464-
const llama_batch & batch,
465-
bool logits_all) {
452+
llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
466453
return llama_sbatch(batch, hparams.n_embd, true, logits_all);
467454
}
468455

469-
llama_ubatch llama_kv_cache_unified::ubatch_next(
470-
llama_sbatch & sbatch,
471-
uint32_t n_ubatch,
472-
bool embd_pooled) const {
456+
llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
473457
GGML_UNUSED(embd_pooled);
474458
return sbatch.split_simple(n_ubatch);
475459
}
@@ -490,6 +474,29 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
490474
return false;
491475
}
492476

477+
//#define FIND_SLOT_DEBUG 1
478+
#if FIND_SLOT_DEBUG
479+
LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
480+
481+
// for debugging
482+
{
483+
std::string ss;
484+
if (n_swa > 0) {
485+
for (uint32_t i = 0; i < size; ++i) {
486+
if (cells[i].pos == -1) {
487+
ss += '.';
488+
} else {
489+
ss += std::to_string(*cells[i].seq_id.begin());
490+
}
491+
if (i%256 == 255) {
492+
ss += '\n';
493+
}
494+
}
495+
}
496+
LLAMA_LOG_WARN("\n%s\n", ss.c_str());
497+
}
498+
#endif
499+
493500
uint32_t n_tested = 0;
494501

495502
while (true) {
@@ -520,6 +527,11 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
520527
}
521528

522529
for (uint32_t i = 0; i < n_tokens; ++i) {
530+
// remember the original state
531+
if (recovery.cells.find(head + i) == recovery.cells.end()) {
532+
recovery.cells[head + i] = cells[head + i];
533+
}
534+
523535
cells[head + i].pos = ubatch.pos[i];
524536

525537
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
@@ -529,18 +541,25 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
529541

530542
used += n_tokens;
531543

532-
pending.ubatches.push_back({ head, ubatch });
533-
534544
// a heuristic, to avoid attending the full cache if it is not yet utilized
535545
// after enough generations, the benefit from this heuristic disappears
536546
// if we start defragmenting the cache, the benefit from this will be more important
537547
n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
538548

539-
//printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
549+
states.push_back({head, n});
550+
551+
#ifdef FIND_SLOT_DEBUG
552+
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
553+
#endif
540554

541555
return true;
542556
}
543557

558+
// Re-applies the `head` and `n` values that were recorded in states[i].
// find_slot() appends one {head, n} entry to `states` each time it succeeds, and
// the decode loop in llama-context.cpp calls set_state(i) right before it
// processes ubatches[i].
// NOTE(review): `i` indexes `states` without a bounds check - presumably callers
// guarantee 0 <= i < states.size(); confirm.
void llama_kv_cache_unified::set_state(int i) {
559+
head = states[i].head;
560+
n = states[i].n;
561+
}
562+
544563
int32_t llama_kv_cache_unified::get_n_tokens() const {
545564
int32_t result = 0;
546565

@@ -642,6 +661,34 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
642661
return ggml_cpy(ctx, v_cur, v_view);
643662
}
644663

664+
// Walks every cell of the cache and, for each cell whose position is reported as
// masked by is_masked_swa(cell pos, p1), detaches sequence ids from it:
//   - seq_id <  0 : all sequence ids are removed from the cell
//   - seq_id >= 0 : only `seq_id` is removed; cells that do not hold it are skipped
// A cell that ends up empty has its position reset to -1, and `used` is decremented
// if the cell was holding a valid position.
// llama_kv_cache_unified_iswa::commit() calls this with the per-sequence pos_max
// as `p1`.
void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
665+
// no pruning is needed when the cache does not use SWA
666+
GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
667+
668+
for (uint32_t i = 0; i < size; ++i) {
669+
const llama_pos p0 = cells[i].pos;
670+
671+
// is_masked_swa() also returns true when p0 < 0, so cells with pos == -1 enter this branch too
if (is_masked_swa(p0, p1)) {
672+
if (seq_id < 0) {
673+
cells[i].seq_id.clear();
674+
} else if (cells[i].has_seq_id(seq_id)) {
675+
cells[i].seq_id.erase(seq_id);
676+
} else {
677+
// the cell does not hold this sequence - leave it untouched
continue;
678+
}
679+
680+
if (cells[i].is_empty()) {
681+
// keep count of the number of used cells
// (the pos >= 0 guard avoids decrementing `used` for a cell whose pos was already negative)
682+
if (cells[i].pos >= 0) {
683+
used--;
684+
}
685+
686+
cells[i].pos = -1;
687+
}
688+
}
689+
}
690+
}
691+
645692
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
646693
const int64_t n_tokens = ubatch->n_tokens;
647694
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -1181,6 +1228,10 @@ uint32_t llama_kv_cache_unified::cell_max() const {
11811228
}
11821229

11831230
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
1231+
if (p0 < 0) {
1232+
return true;
1233+
}
1234+
11841235
switch (swa_type) {
11851236
case LLAMA_SWA_TYPE_NONE:
11861237
{
@@ -1653,26 +1704,20 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
16531704
// Forwards restore() to both wrapped caches (kv_base, then kv_swa) and then
// clears this object's own `states` list.
void llama_kv_cache_unified_iswa::restore() {
16541705
kv_base->restore();
16551706
kv_swa ->restore();
1707+
states.clear();
16561708
}
16571709

16581710
void llama_kv_cache_unified_iswa::commit() {
16591711
kv_base->commit();
16601712
kv_swa ->commit();
16611713

1662-
if (pending.pos_max.empty()) {
1663-
return;
1664-
}
1665-
16661714
// slide the attention window, forgetting/pruning old tokens that are outside the window
16671715
for (const auto & [seq_id, pos_max] : pending.pos_max) {
1668-
if (pos_max <= (llama_pos) hparams.n_swa) {
1669-
continue;
1670-
}
1671-
1672-
kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
1716+
kv_swa->prune_swa(seq_id, pos_max);
16731717
}
16741718

1675-
pending.pos_max.clear();
1719+
pending.clear();
1720+
states.clear();
16761721
}
16771722

16781723
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
@@ -1695,12 +1740,18 @@ void llama_kv_cache_unified_iswa::set_full() {
16951740
}
16961741

16971742
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
1743+
pending.pos_max.clear();
1744+
16981745
for (int i = 0; i < batch.n_tokens; ++i) {
16991746
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
17001747
const llama_seq_id seq_id = batch.seq_id[i][s];
17011748
const llama_pos pos = batch.pos[i];
17021749

1703-
pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
1750+
if (pending.pos_max.find(seq_id) == pending.pos_max.end()) {
1751+
pending.pos_max[seq_id] = pos;
1752+
} else {
1753+
pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
1754+
}
17041755
}
17051756
}
17061757

@@ -1721,6 +1772,11 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
17211772
return res;
17221773
}
17231774

1775+
// Forwards the state index `i` to both wrapped caches (kv_base, then kv_swa).
// No state of this wrapper object itself is read or modified here.
// NOTE(review): the same index is passed to both caches - presumably each one
// records one state per find_slot() call so the indices line up; confirm.
void llama_kv_cache_unified_iswa::set_state(int i) {
1776+
kv_base->set_state(i);
1777+
kv_swa ->set_state(i);
1778+
}
1779+
17241780
int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
17251781
return kv_base->get_n_tokens();
17261782
}
@@ -2090,6 +2146,8 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
20902146
}
20912147

20922148
void llama_kv_cache_recurrent::restore() {
2149+
states.clear();
2150+
20932151
if (pending.ranges.empty()) {
20942152
return;
20952153
}
@@ -2098,6 +2156,7 @@ void llama_kv_cache_recurrent::restore() {
20982156
}
20992157

21002158
// Discards the bookkeeping kept for the current batch: empties the `states` list
// and the pending ranges. Nothing else is modified in this function.
void llama_kv_cache_recurrent::commit() {
2159+
states.clear();
21012160
pending.ranges.clear();
21022161
}
21032162

@@ -2306,6 +2365,11 @@ bool llama_kv_cache_recurrent::find_slot(
23062365
return n >= n_seqs;
23072366
}
23082367

2368+
// Re-applies the `head` and `n` values that were recorded in states[i].
// NOTE(review): `i` indexes `states` without a bounds check, and the code that
// fills `states` for this cache type is not visible here - presumably callers
// guarantee 0 <= i < states.size(); confirm.
void llama_kv_cache_recurrent::set_state(int i) {
2369+
head = states[i].head;
2370+
n = states[i].n;
2371+
}
2372+
23092373
int32_t llama_kv_cache_recurrent::get_n_tokens() const {
23102374
int32_t result = 0;
23112375

0 commit comments

Comments
 (0)