
Commit 1786c29

ggerganov authored and qnixsynapse committed
kv-cells : fix tracking of seq_pos (ggml-org#14339)
* kv-cells : fix tracking of seq_pos during cache reuse

  ggml-ci

* cont : improve error message

  ggml-ci

* cont : add more comments
1 parent 5b96f59 · commit 1786c29

File tree

5 files changed: +56 -17 lines changed

* include/llama.h
* src/llama-batch.cpp
* src/llama-context.cpp
* src/llama-kv-cells.h
* tools/server/server.cpp

include/llama.h

Lines changed: 5 additions & 3 deletions

```diff
@@ -944,12 +944,14 @@ extern "C" {
     // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon non-zero return values, the memory state is restored to the state before this call
+    // Upon fatal-error or abort, the ubatches that managed to be processed will remain in the memory state of the context
+    //   To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
+    // Upon other return values, the memory state is restored to the state before this call
     //    0 - success
     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    //    2 - aborted
+    //    2 - aborted  (processed ubatches will remain in the context's memory)
     //   -1 - invalid input batch
-    // < -1 - error
+    // < -1 - fatal error  (processed ubatches will remain in the context's memory)
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
             struct llama_batch batch);
```
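For callers, the practical consequence is that a failed `llama_decode()` no longer guarantees an untouched KV cache. Below is a minimal recovery sketch, assuming the `llama_get_memory()` and `llama_memory_seq_rm()` declarations from the same header; the rollback policy itself is illustrative, not part of this commit:

```cpp
#include "llama.h"

#include <cstdio>

// decode a batch and, on a fatal error or abort, roll the sequence back to a
// known-good position so that future batches keep consecutive positions
static int32_t decode_checked(llama_context * ctx, llama_batch batch,
                              llama_seq_id seq, llama_pos n_past) {
    const int32_t ret = llama_decode(ctx, batch);

    if (ret == 0 || ret == 1 || ret == -1) {
        return ret; // 0: success; 1 / -1: memory was restored, safe to retry
    }

    // ret == 2 (abort) or ret < -1 (fatal error): some ubatches may already
    // be in the KV cache, so query the memory module to see how far we got
    llama_memory_t mem = llama_get_memory(ctx);

    fprintf(stderr, "seq %d spans [%d, %d] in the KV cache after failure\n",
            seq,
            llama_memory_seq_pos_min(mem, seq),
            llama_memory_seq_pos_max(mem, seq));

    // one possible policy: drop everything past the caller's known-good
    // position, so the next batch can start at n_past again
    llama_memory_seq_rm(mem, seq, n_past, -1);

    return ret;
}
```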

src/llama-batch.cpp

Lines changed: 15 additions & 4 deletions

```diff
@@ -245,21 +245,32 @@ bool llama_batch_allocr::init(
         }

         if (memory) {
+            bool ok = true;
+
             if (batch.token) {
                 if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                    ok = false;
                 }
             } else {
                 assert(batch.embd);

                 // for embeddings (typically used as vision input), we allow them to have repeating positions
                 // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
                 if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
-                    LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
-                    return false;
+                    ok = false;
                 }
             }
+
+            if (!ok) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, memory->seq_pos_max(s), s, seq_pos_min(s));
+
+                return false;
+            }
         }

         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
```
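Since the check above is split across two branches, here is the invariant in isolation as a standalone sketch (the function name and `main` driver are invented for illustration; the logic mirrors the diff): a token batch for a sequence must start exactly one past the last cached position (Y = X + 1), while an embedding batch may additionally repeat X.

```cpp
#include <cstdio>

// X = last position stored in the KV cache, Y = first position in the batch
static bool positions_ok(int x_cache_max, int y_batch_min, bool is_embd) {
    if (y_batch_min == x_cache_max + 1) {
        return true; // normal continuation: Y = X + 1
    }
    if (is_embd && y_batch_min == x_cache_max) {
        return true; // repeated position, e.g. vision input embeddings
    }
    return false;
}

int main() {
    printf("%d\n", positions_ok(9, 10, false)); // 1 - consecutive
    printf("%d\n", positions_ok(9,  9, false)); // 0 - token batches must not repeat
    printf("%d\n", positions_ok(9,  9, true));  // 1 - embeddings may repeat
    printf("%d\n", positions_ok(9, 11, false)); // 0 - gap in positions
}
```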

src/llama-context.cpp

Lines changed: 0 additions & 1 deletion

```diff
@@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
             }

-            // TODO: fix sequence indexing
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];

```

src/llama-kv-cells.h

Lines changed: 33 additions & 9 deletions

```diff
@@ -7,6 +7,7 @@
 #include <cassert>
 #include <vector>
 #include <set>
+#include <map>

 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -216,7 +217,7 @@ class llama_kv_cells_unified {
         assert(seq_id >= 0);

         seq[i].reset(seq_id);
-        seq_pos[seq_id].erase(pos[i]);
+        seq_pos_dec(seq_id, pos[i]);

         if (seq[i].none()) {
             pos[i] = -1;
@@ -239,7 +240,7 @@ class llama_kv_cells_unified {
         seq[i].reset();

         seq[i].set(seq_id);
-        seq_pos[seq_id].insert(pos[i]);
+        seq_pos_inc(seq_id, pos[i]);

         return false;
     }
@@ -284,7 +285,7 @@ class llama_kv_cells_unified {
         assert(!seq[i].test(seq_id));

         seq[i].set(seq_id);
-        seq_pos[seq_id].insert(pos[i]);
+        seq_pos_inc(seq_id, pos[i]);
     }

     // return the sequence id of this cell
@@ -311,7 +312,9 @@ class llama_kv_cells_unified {
             return -1;
         }

-        return *seq_pos[seq_id].begin();
+        assert(seq_pos[seq_id].begin()->second > 0);
+
+        return seq_pos[seq_id].begin()->first;
     }

     // the maximum position of sequence seq_id currently present in any of the cells
@@ -324,7 +327,9 @@ class llama_kv_cells_unified {
             return -1;
         }

-        return *seq_pos[seq_id].rbegin();
+        assert(seq_pos[seq_id].rbegin()->second > 0);
+
+        return seq_pos[seq_id].rbegin()->first;
     }

     // note: call only if the cell is not empty
@@ -441,17 +446,36 @@ class llama_kv_cells_unified {
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
     std::vector<seq_set_t> seq;

-    // the set seq_pos[s] tells us which positions are currently present for sequence s
+    // the map seq_pos[s][p] tells us how many times the position p is currently present for sequence s
+    // if the position p is not present, seq_pos[s][p] is not set
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
-    std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
+    //
+    // note that we cannot use an std::set because in some cases a position can occur more than once for the same seq:
+    //   - when performing a cache reuse via (rm + add)
+    //   - some vision models have input embeddings with repeating positions
+    //
+    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];

     // helper functions for updating `seq_pos`, one cell at a time:

+    void seq_pos_dec(llama_seq_id s, llama_pos p) {
+        auto it = seq_pos[s].find(p);
+        assert(it != seq_pos[s].end());
+
+        if (--it->second == 0) {
+            seq_pos[s].erase(it);
+        }
+    }
+
+    void seq_pos_inc(llama_seq_id s, llama_pos p) {
+        seq_pos[s][p]++;
+    }
+
     // remove cell i
     void seq_pos_rm(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].erase(pos[i]);
+                seq_pos_dec(s, pos[i]);
             }
         }
     }
@@ -460,7 +484,7 @@ class llama_kv_cells_unified {
     void seq_pos_add(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
-                seq_pos[s].insert(pos[i]);
+                seq_pos_inc(s, pos[i]);
             }
         }
     }
```
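The core of the fix is this switch from `std::set<llama_pos>` to `std::map<llama_pos, int>`: positions are now reference-counted per sequence. Below is a self-contained sketch of why the count matters during cache reuse (plain `int` positions, a single sequence, and the `inc`/`dec` lambdas invented to mirror `seq_pos_inc`/`seq_pos_dec` above):

```cpp
#include <cassert>
#include <map>

int main() {
    // position -> number of cells currently holding it, for one sequence
    std::map<int, int> seq_pos;

    auto inc = [&](int p) { seq_pos[p]++; };
    auto dec = [&](int p) {
        auto it = seq_pos.find(p);
        assert(it != seq_pos.end());
        if (--it->second == 0) {
            seq_pos.erase(it);
        }
    };

    // during cache reuse, position 5 can briefly be held by two cells at once
    inc(5); // original cell
    inc(5); // reused cell
    dec(5); // original cell released

    // with a plain std::set, the single erase(5) above would have dropped the
    // position entirely, and the min/max queries would report wrong bounds
    assert(seq_pos.begin()->first == 5);  // min position is still tracked
    assert(seq_pos.rbegin()->first == 5); // max position is still tracked

    dec(5);
    assert(seq_pos.empty());

    return 0;
}
```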

tools/server/server.cpp

Lines changed: 3 additions & 0 deletions

```diff
@@ -3418,9 +3418,12 @@ struct server_context {
             }

             if (ret < -1) {
+                // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
                 err = "Compute error.";
             }

+            // TODO: handle ret == 2 (abort) when we start aborting
+
             if (!err.empty()) {
                 SRV_ERR("%s, i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret);
                 for (auto & slot : slots) {
```
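A speculative sketch of what the first TODO might become, assuming the `server_slot` fields `id` (the slot's sequence id) and `n_past` (its cached-position counter) keep their current meaning; nothing here is part of this commit:

```cpp
if (ret < -1) {
    // resync each slot with what actually remained in the KV cache
    llama_memory_t mem = llama_get_memory(ctx);

    for (auto & slot : slots) {
        const llama_pos p_max = llama_memory_seq_pos_max(mem, slot.id);

        // p_max == -1 means the sequence is empty; otherwise positions up to
        // p_max survived the failed llama_decode() call
        slot.n_past = p_max + 1;
    }

    err = "Compute error.";
}
```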
