Commit 9bcee62

kv-cache : fix bug with n_ubatch < n_batch
ggml-ci
1 parent b02a5a8 commit 9bcee62

File tree: 3 files changed (+156, -67 lines)

src/llama-context.cpp

Lines changed: 10 additions & 1 deletion
@@ -93,6 +93,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
     cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -916,6 +917,9 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -2;
     };
 
+    // =============================================================================================================
+    // TODO: refactor the llama_kv_cache interface and simplify this
+
     // handle any pending defrags/shifts
     kv_self_update();
 
@@ -967,11 +971,16 @@ int llama_context::decode(llama_batch & inp_batch) {
         ubatches.clear();
     }
 
+    // =============================================================================================================
+
     // we now have prepared the ubatches for this llama_decode and are ready to start processing
 
     int64_t n_outputs_prev = 0;
 
-    for (const auto & ubatch : ubatches) {
+    for (int i = 0; i < (int) ubatches.size(); ++i) {
+        const auto & ubatch = ubatches[i];
+        kv_self->set_state(i);
+
         // count the outputs in this u_batch
         {
             int32_t n_outputs_new = 0;
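This loop change is the core of the fix. When n_ubatch < n_batch, llama_decode splits the batch into several ubatches, and every find_slot call advances the cache's head and recomputes n; before this commit the graph for each ubatch was built against whatever head/n the last find_slot left behind, which is only correct when the whole batch fits in a single ubatch (inferred from the diff; the commit title points the same way). A minimal sketch of the pattern, with simplified types - kv_state and the member names are illustrative, not the actual header:

#include <cstdint>
#include <vector>

// illustrative stand-ins for the real cache members
struct kv_state {
    uint32_t head; // slot index chosen by find_slot() for this ubatch
    uint32_t n;    // cache extent the attention graph must cover
};

struct kv_cache_sketch {
    uint32_t head = 0;
    uint32_t n    = 0;

    std::vector<kv_state> states; // one entry per find_slot() call

    // recorded at the end of each successful find_slot()
    void remember() {
        states.push_back({head, n});
    }

    // replayed before building the graph for ubatch i,
    // mirroring kv_self->set_state(i) in the loop above
    void set_state(int i) {
        head = states[i].head;
        n    = states[i].n;
    }
};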

src/llama-kv-cache.cpp

Lines changed: 109 additions & 48 deletions
@@ -333,44 +333,32 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    if (pending.ubatches.empty()) {
-        return;
-    }
-
-    uint32_t new_head = size;
-
-    for (const auto & ubatch : pending.ubatches) {
-        for (uint32_t i = 0; i < ubatch.data.n_tokens; ++i) {
-            for (int s = 0; s < ubatch.data.n_seq_id[i]; ++s) {
-                const llama_seq_id seq_id = ubatch.data.seq_id[i][s];
+    for (const auto & [id, cell] : recovery.cells) {
+        // TODO: move to new `struct kv_cells`
+        const bool is_empty0 = cells[id].is_empty();
+        const bool is_empty1 = cell.is_empty();
 
-                cells[ubatch.head + i].seq_id.erase(seq_id);
-                if (cells[ubatch.head + i].seq_id.empty()) {
-                    used--;
-
-                    new_head = std::min(new_head, ubatch.head + i);
-                }
-
-                cells[ubatch.head + i].pos = -1;
-            }
+        if (!is_empty0 && is_empty1) {
+            used--;
+        } else if (is_empty0 && !is_empty1) {
+            used++;
         }
-    }
 
-    if (new_head != size && new_head < head) {
-        head = new_head;
+        cells[id] = cell;
     }
 
-    pending.clear();
+    recovery.clear();
+    states.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (pending.ubatches.empty()) {
-        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
-                __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+    if (recovery.cells.empty()) {
+        LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
         return;
     }
 
-    pending.clear();
+    recovery.clear();
 }
 
 bool llama_kv_cache_unified::update(llama_context & lctx) {
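restore() no longer walks the pending ubatches to undo sequence ids by hand; it simply copies back the snapshotted cells from recovery.cells (presumably a map from cell index to saved cell, declared in the header). The only subtlety is the used counter, which tracks how many cells are occupied: rolling a cell back can flip it between empty and non-empty in either direction, so both transitions are counted. A standalone sketch of that bookkeeping, assuming a simplified kv_cell:

#include <cstdint>
#include <map>
#include <set>
#include <vector>

struct kv_cell {
    int32_t pos = -1;
    std::set<int32_t> seq_id;

    bool is_empty() const { return seq_id.empty(); }
};

struct cache_sketch {
    std::vector<kv_cell> cells = std::vector<kv_cell>(1024);
    std::map<uint32_t, kv_cell> recovery; // cell index -> pre-batch contents
    uint32_t used = 0;                    // number of non-empty cells

    void restore() {
        for (const auto & [id, cell] : recovery) {
            const bool is_empty0 = cells[id].is_empty(); // current state
            const bool is_empty1 = cell.is_empty();      // snapshotted state

            if (!is_empty0 && is_empty1) {
                used--; // the failed batch had filled this cell
            } else if (is_empty0 && !is_empty1) {
                used++; // the failed batch had emptied this cell
            }

            cells[id] = cell;
        }

        recovery.clear();
    }
};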
@@ -460,16 +448,11 @@ void llama_kv_cache_unified::set_full() {
     head = 0;
 }
 
-llama_sbatch llama_kv_cache_unified::sbatch_init(
-        const llama_batch & batch,
-        bool logits_all) {
+llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
     return llama_sbatch(batch, hparams.n_embd, true, logits_all);
 }
 
-llama_ubatch llama_kv_cache_unified::ubatch_next(
-        llama_sbatch & sbatch,
-        uint32_t n_ubatch,
-        bool embd_pooled) const {
+llama_ubatch llama_kv_cache_unified::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
     GGML_UNUSED(embd_pooled);
     return sbatch.split_simple(n_ubatch);
 }
@@ -490,6 +473,29 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
         return false;
     }
 
+//#define FIND_SLOT_DEBUG 1
+#if FIND_SLOT_DEBUG
+    LLAMA_LOG_WARN("begin: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
+
+    // for debugging
+    {
+        std::string ss;
+        if (n_swa > 0) {
+            for (uint32_t i = 0; i < size; ++i) {
+                if (cells[i].pos == -1) {
+                    ss += '.';
+                } else {
+                    ss += std::to_string(*cells[i].seq_id.begin());
+                }
+                if (i%256 == 255) {
+                    ss += '\n';
+                }
+            }
+        }
+        LLAMA_LOG_WARN("\n%s\n", ss.c_str());
+    }
+#endif
+
     uint32_t n_tested = 0;
 
     while (true) {
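The new debug block renders the cache as one character per cell: '.' for an empty cell, otherwise the first sequence id stored in it, with a line break every 256 cells. Note the guards differ in style: this "begin" dump uses #if FIND_SLOT_DEBUG while the "end" dump further down uses #ifdef FIND_SLOT_DEBUG; both stay inert while the #define is commented out. A self-contained toy that produces the same kind of map (the setup data here is invented purely for illustration):

#include <cstdio>
#include <set>
#include <string>
#include <vector>

struct cell {
    int pos = -1;
    std::set<int> seq_id;
};

int main() {
    std::vector<cell> cells(512);

    // pretend sequence 0 occupies the first five cells
    for (int i = 0; i < 5; ++i) {
        cells[i].pos = i;
        cells[i].seq_id.insert(0);
    }

    std::string ss;
    for (size_t i = 0; i < cells.size(); ++i) {
        if (cells[i].pos == -1) {
            ss += '.';
        } else {
            ss += std::to_string(*cells[i].seq_id.begin());
        }
        if (i % 256 == 255) {
            ss += '\n';
        }
    }

    printf("%s\n", ss.c_str()); // "00000....." then a second row of dots
}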
@@ -520,6 +526,11 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     }
 
     for (uint32_t i = 0; i < n_tokens; ++i) {
+        // remember the original state
+        if (recovery.cells.find(head + i) == recovery.cells.end()) {
+            recovery.cells[head + i] = cells[head + i];
+        }
+
         cells[head + i].pos = ubatch.pos[i];
 
         for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
529540

530541
used += n_tokens;
531542

532-
pending.ubatches.push_back({ head, ubatch });
533-
534543
// a heuristic, to avoid attending the full cache if it is not yet utilized
535544
// after enough generations, the benefit from this heuristic disappears
536545
// if we start defragmenting the cache, the benefit from this will be more important
537546
n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
538547

539-
//printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
548+
states.push_back({head, n});
549+
550+
#ifdef FIND_SLOT_DEBUG
551+
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
552+
#endif
540553

541554
return true;
542555
}
543556

557+
void llama_kv_cache_unified::set_state(int i) {
558+
head = states[i].head;
559+
n = states[i].n;
560+
}
561+
544562
int32_t llama_kv_cache_unified::get_n_tokens() const {
545563
int32_t result = 0;
546564

@@ -642,6 +660,34 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
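The states vector and its element type come from the commit's third changed file (presumably the KV cache header), which is not shown in this view. From the usage above, each successful find_slot appends a {head, n} pair and set_state(i) replays pair i before ubatch i is processed. A guess at the header-side shape, illustrative only:

#include <cstdint>
#include <vector>

// NOTE: assumed declarations - the actual header diff is not part of this
// view; the shape is inferred from how find_slot() and set_state() use it
struct kv_state { // hypothetical name
    uint32_t head = 0; // value of `head` right after find_slot() succeeded
    uint32_t n    = 0; // value of `n` computed by that find_slot()
};

// one entry per ubatch of the batch currently being decoded;
// cleared in restore(), indexed by set_state(i)
std::vector<kv_state> states;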
642660
return ggml_cpy(ctx, v_cur, v_view);
643661
}
644662

663+
void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos p1) {
664+
// no pruning is needed when the cache does not use SWA
665+
GGML_ASSERT(swa_type != LLAMA_SWA_TYPE_NONE && "do not prune non-SWA cache");
666+
667+
for (uint32_t i = 0; i < size; ++i) {
668+
const llama_pos p0 = cells[i].pos;
669+
670+
if (is_masked_swa(p0, p1)) {
671+
if (seq_id < 0) {
672+
cells[i].seq_id.clear();
673+
} else if (cells[i].has_seq_id(seq_id)) {
674+
cells[i].seq_id.erase(seq_id);
675+
} else {
676+
continue;
677+
}
678+
679+
if (cells[i].is_empty()) {
680+
// keep count of the number of used cells
681+
if (cells[i].pos >= 0) {
682+
used--;
683+
}
684+
685+
cells[i].pos = -1;
686+
}
687+
}
688+
}
689+
}
690+
645691
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
646692
const int64_t n_tokens = ubatch->n_tokens;
647693
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
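prune_swa drops every cell whose position is already outside the sliding window relative to p1, the newest position of the sequence. Empty cells have pos == -1, and is_masked_swa (amended below) treats p0 < 0 as masked, so the loop can scan all cells: a cell that never carried the sequence falls through to continue, and the pos >= 0 check keeps used accurate for cells that were already empty. Assuming the standard sliding-window rule masks p0 when p1 - p0 >= n_swa (an assumption about is_masked_swa, whose body is not shown here), a worked example:

#include <cstdio>

int main() {
    const int n_swa = 4;  // window size
    const int p1    = 10; // newest position seen for the sequence

    for (int p0 = 0; p0 <= p1; ++p0) {
        const bool masked = (p1 - p0) >= n_swa; // assumed standard SWA rule
        printf("p0 = %2d -> %s\n", p0, masked ? "prune" : "keep");
    }

    // positions 0..6 are pruned; 7..10 remain inside the window, which
    // matches the old seq_rm(seq_id, -1, pos_max - n_swa + 1) boundary
}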
@@ -1181,6 +1227,10 @@ uint32_t llama_kv_cache_unified::cell_max() const {
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+    if (p0 < 0) {
+        return true;
+    }
+
     switch (swa_type) {
         case LLAMA_SWA_TYPE_NONE:
             {
@@ -1653,26 +1703,19 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 void llama_kv_cache_unified_iswa::restore() {
     kv_base->restore();
     kv_swa ->restore();
+    states.clear();
 }
 
 void llama_kv_cache_unified_iswa::commit() {
     kv_base->commit();
     kv_swa ->commit();
 
-    if (pending.pos_max.empty()) {
-        return;
-    }
-
     // slide the attention window, forgetting/pruning old tokens that are outside the window
     for (const auto & [seq_id, pos_max] : pending.pos_max) {
-        if (pos_max <= (llama_pos) hparams.n_swa) {
-            continue;
-        }
-
-        kv_swa->seq_rm(seq_id, -1, pos_max - hparams.n_swa + 1);
+        kv_swa->prune_swa(seq_id, pos_max);
     }
 
-    pending.pos_max.clear();
+    pending.clear();
 }
 
 bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
@@ -1695,12 +1738,18 @@ void llama_kv_cache_unified_iswa::set_full() {
 }
 
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
+    pending.pos_max.clear();
+
     for (int i = 0; i < batch.n_tokens; ++i) {
         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
             const llama_seq_id seq_id = batch.seq_id[i][s];
             const llama_pos pos = batch.pos[i];
 
-            pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+            if (pending.pos_max.find(seq_id) == pending.pos_max.end()) {
+                pending.pos_max[seq_id] = pos;
+            } else {
+                pending.pos_max[seq_id] = std::max(pending.pos_max[seq_id], pos);
+            }
         }
     }
 
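Two things change in sbatch_init: pending.pos_max is now cleared at the start of every batch, so maxima can no longer carry over from an earlier batch (for instance one that failed and was rolled back before commit() could clear them), and the first insertion for a sequence id is made explicit. With std::map the old one-liner already behaved correctly for non-negative positions, because operator[] value-initializes a missing key to 0 before std::max runs; the find() form simply stops relying on that default. A small demonstration of the difference:

#include <algorithm>
#include <cstdio>
#include <map>

int main() {
    std::map<int, int> pos_max;

    // old style: operator[] value-initializes the missing key to 0 first
    pos_max[7] = std::max(pos_max[7], 42);

    // new style: explicit first insertion, std::max only afterwards
    if (pos_max.find(8) == pos_max.end()) {
        pos_max[8] = 42;
    } else {
        pos_max[8] = std::max(pos_max[8], 42);
    }

    printf("%d %d\n", pos_max[7], pos_max[8]); // 42 42 - same result here
}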
@@ -1721,6 +1770,11 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
     return res;
 }
 
+void llama_kv_cache_unified_iswa::set_state(int i) {
+    kv_base->set_state(i);
+    kv_swa ->set_state(i);
+}
+
 int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
     return kv_base->get_n_tokens();
 }
@@ -2090,6 +2144,8 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_recurrent::restore() {
+    states.clear();
+
     if (pending.ranges.empty()) {
         return;
     }
@@ -2306,6 +2362,11 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
+void llama_kv_cache_recurrent::set_state(int i) {
+    head = states[i].head;
+    n    = states[i].n;
+}
+
 int32_t llama_kv_cache_recurrent::get_n_tokens() const {
     int32_t result = 0;
 