
Commit 91b7792

cont : fix position auto-gen + add comments

ggml-ci

1 parent 2437143

2 files changed: +18 -4 lines

src/llama-batch.cpp

Lines changed: 15 additions & 2 deletions
```diff
@@ -366,7 +366,12 @@ bool llama_batch_allocr::init(
 
         for (int32_t i = 0; i < batch.n_tokens; i++) {
             const llama_seq_id seq_id = batch.seq_id[i][0];
-            pos[i] = p0[seq_id] + i;
+
+            pos[i] = p0[seq_id];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                p0[batch.seq_id[i][s]] = pos[i] + 1;
+            }
         }
 
         batch.pos = pos.data();
```
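The old auto-generation added the batch-wide token index `i` to the sequence's starting position, which drifts as soon as a batch mixes tokens from more than one sequence. The new code keeps a per-sequence counter in `p0` and advances it for every sequence a token belongs to. A minimal standalone sketch of the same idea (`toy_batch` and `autogen_pos` are hypothetical names with simplified types, not the actual llama.cpp code path):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-ins for the real batch fields.
struct toy_batch {
    int32_t n_tokens;
    std::vector<std::vector<int32_t>> seq_id; // sequences each token belongs to
};

// Auto-generate positions: each token continues from the current position of
// its primary sequence, and every sequence it belongs to is advanced past it.
static std::vector<int32_t> autogen_pos(const toy_batch & batch, std::vector<int32_t> p0) {
    std::vector<int32_t> pos(batch.n_tokens);

    for (int32_t i = 0; i < batch.n_tokens; i++) {
        const int32_t s0 = batch.seq_id[i][0];

        pos[i] = p0[s0];

        // advance the starting position of every sequence this token is part of
        for (int32_t s : batch.seq_id[i]) {
            p0[s] = pos[i] + 1;
        }
    }

    return pos;
}
```

For example, a batch holding tokens for sequences 0, 1, 0 (in that order) now yields positions `p0[0]`, `p0[1]`, `p0[0] + 1`, whereas the old `p0[seq_id] + i` would have produced `p0[0]`, `p0[1] + 1`, `p0[0] + 2`, leaving gaps in both sequences.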
```diff
@@ -397,7 +402,11 @@ bool llama_batch_allocr::init(
                 const llama_seq_id s0 = batch.seq_id[i][0];
                 const llama_seq_id s1 = batch.seq_id[i][s];
 
+                // mark that sequence s1 is coupled to s0
                 seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                //seq_cpl[s0][s1] = true;
             }
         }
     }
```
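The added comments document the direction of the coupling table: when a token carries more than one sequence id, each secondary sequence `s1` is marked as coupled to the primary sequence `s0`, and only in that direction. A hedged sketch of the same bookkeeping over a simplified boolean table (`mark_coupled` and its parameters are illustrative names, not the library's API):

```cpp
#include <cstdint>
#include <vector>

// Illustrative sketch: seq_cpl[s1][s0] == true means "sequence s1 is coupled
// to sequence s0". Only the secondary -> primary direction is recorded.
static void mark_coupled(const std::vector<std::vector<int32_t>> & seq_ids_per_token,
                         std::vector<std::vector<bool>> & seq_cpl) {
    for (const auto & ids : seq_ids_per_token) {
        const int32_t s0 = ids[0];

        for (size_t k = 1; k < ids.size(); ++k) {
            const int32_t s1 = ids[k];

            seq_cpl[s1][s0] = true;
            // the reverse direction (seq_cpl[s0][s1]) is intentionally left
            // unset, mirroring the note in the commit
        }
    }
}
```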
```diff
@@ -467,6 +476,10 @@ bool llama_batch_allocr::init(
         }
     }
 
+    //
+    // consistency checks
+    //
+
     for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
         if (seq_pos[s].empty()) {
            continue;
```
```diff
@@ -478,7 +491,7 @@ bool llama_batch_allocr::init(
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
-            LLAMA_LOG_ERROR("%s: sequence %d is not contiguous\n", __func__, s);
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
             return false;
         }
     }
```
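The renamed error message matches what the check actually verifies: `seq_pos[s]` collects the distinct positions seen for sequence `s`, and if the span `max - min + 1` exceeds the number of distinct positions, at least one position in between is missing. A small hedged sketch of that rule over a `std::set` (assumed container; the real code uses its own `seq_pos_min`/`seq_pos_max` helpers):

```cpp
#include <cstdint>
#include <set>

// A gap-free sequence must cover every position in [min, max].
static bool positions_continuous(const std::set<int32_t> & seq_pos) {
    if (seq_pos.empty()) {
        return true; // nothing to check
    }

    const int32_t p_min = *seq_pos.begin();
    const int32_t p_max = *seq_pos.rbegin();

    // if the span is larger than the number of distinct positions, there is a hole
    return p_max - p_min + 1 <= (int32_t) seq_pos.size();
}
```

For example, positions {10, 11, 13} span 4 slots but contain only 3 entries, so such a batch would be rejected with the error above.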

src/llama-batch.h

Lines changed: 3 additions & 2 deletions
```diff
@@ -78,12 +78,13 @@ struct llama_sbatch {
     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
-// temporary allocate memory for the input batch if needed
+// a helper for sanitizing and fulfilling a batch
 class llama_batch_allocr {
 public:
     llama_batch_allocr();
 
-    // optionally fulfill the batch returned by llama_batch_get_one
+    // sanitize and auto-gen missing data in the input batch
+    // memory is optional. if provided, it will be used to check for sequence continuity
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
```
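The new doc comment says the memory argument is optional and, when present, is used to check sequence continuity. The rest of the `init` signature is truncated in this diff, so the following is a hypothetical illustration only (not the library's API) of what such a continuity rule can look like: a sequence in the incoming batch starts right after the last position already stored for it.

```cpp
#include <cstdint>

// Hypothetical illustration: with memory available, a batch entry for a
// sequence is continuous if it starts immediately after the last position
// stored for that sequence; -1 is assumed to mean "sequence empty in memory".
static bool continues_from_memory(int32_t pos_mem_max, int32_t pos_batch_min) {
    return pos_mem_max == -1 || pos_batch_min == pos_mem_max + 1;
}
```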
