Commit b67fe0d

cont : auto-gen positions + verify multi-seq input
ggml-ci
1 parent 42b2ae3 commit b67fe0d

File tree

4 files changed: 84 additions, 36 deletions
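The gist of the change: instead of a single caller-supplied start position `p0`, the batch allocator now derives one start position per sequence from the memory (KV cache). A minimal standalone sketch of that idea, where `seq_pos_max_in_memory` is a hypothetical stand-in for `memory->seq_pos_max(s)` (this is an illustration, not the committed code):

```cpp
#include <cstdint>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

constexpr int32_t MAX_SEQ = 64; // mirrors LLAMA_MAX_PARALLEL_SEQUENCES

// hypothetical: highest position stored for sequence s, or -1 if the sequence is empty
llama_pos seq_pos_max_in_memory(llama_seq_id s) {
    (void) s;
    return -1; // pretend the memory is empty
}

// mirrors the new block in llama_batch_allocr::init(): as in the commit,
// pos[i] = p0[seq_id] + i assumes the batch is a single contiguous run
// (e.g. a batch produced by llama_batch_get_one)
std::vector<llama_pos> auto_gen_pos(const std::vector<llama_seq_id> & seq) {
    llama_pos p0[MAX_SEQ];
    for (int32_t s = 0; s < MAX_SEQ; ++s) {
        p0[s] = seq_pos_max_in_memory(s) + 1;
    }

    std::vector<llama_pos> pos(seq.size());
    for (size_t i = 0; i < seq.size(); ++i) {
        pos[i] = p0[seq[i]] + (llama_pos) i;
    }
    return pos;
}

int main() {
    // a 3-token continuation of sequence 0 gets positions 0, 1, 2 here
    const std::vector<llama_seq_id> seq = {0, 0, 0};
    const std::vector<llama_pos> pos = auto_gen_pos(seq);
    return pos[2] == 2 ? 0 : 1;
}
```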

src/llama-batch.cpp

Lines changed: 76 additions & 31 deletions

```diff
@@ -3,6 +3,7 @@
 #include "llama-impl.h"
 #include "llama-cparams.h"
 #include "llama-vocab.h"
+#include "llama-memory.h"
 
 #include <cassert>
 #include <cstring>
@@ -295,7 +296,10 @@ llama_batch_allocr::llama_batch_allocr() {
     }
 }
 
-bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory) {
     clear();
 
     batch = batch_inp;
@@ -306,14 +310,6 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
     // validate input batch
     //
 
-    // TODO: remove
-    if (!batch.pos) {
-        if (batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return false;
-        }
-    }
-
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
@@ -338,15 +334,6 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
     // auto-generate missing fields
     //
 
-    if (!batch.pos) {
-        assert(p0 >= 0);
-        pos.resize(batch.n_tokens);
-        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = p0 + i;
-        }
-        batch.pos = pos.data();
-    }
-
     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
@@ -364,6 +351,27 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.seq_id = seq_id.data();
     }
 
+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (!memory) {
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
+            }
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            pos[i] = p0[seq_id] + i;
+        }
+
+        batch.pos = pos.data();
+    }
+
     if (!batch.logits) {
         // by default return the output only for the last token
         output.resize(batch.n_tokens);
@@ -379,24 +387,54 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
             n_outputs += batch.logits[i] != 0;
         }
 
+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
             seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
 
             if (s > 0) {
-                seq_cpl[batch.seq_id[i][0]][batch.seq_id[i][s]] = true;
+                const llama_seq_id s0 = batch.seq_id[i][0];
+                const llama_seq_id s1 = batch.seq_id[i][s];
+
+                seq_cpl[s1][s0] = true;
             }
         }
     }
 
-    // TODO:
-    //  - verify that coupled sequences have same "position contexts"
-    //  - verify that input sequences are "contiguous" (no position gaps)
-    //  - verify that input sequences begin from the last poition currently in the context
+    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d is not contiguous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but they have diverged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
+        }
+    }
 
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
         LLAMA_LOG_DEBUG("%s: token    = %p\n", __func__, (void *) batch.token);
         LLAMA_LOG_DEBUG("%s: embd     = %p\n", __func__, (void *) batch.embd);
         LLAMA_LOG_DEBUG("%s: pos      = %p\n", __func__, (void *) batch.pos);
@@ -439,14 +477,21 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         }
         LLAMA_LOG_DEBUG("%s: ]\n", __func__);
 
-        LLAMA_LOG_DEBUG("%s: seq_pos = [\n", __func__);
-        for (int s = 0; s < (int) seq_pos.size(); ++s) {
-            const auto & cur = seq_pos[s];
-            if (cur.empty()) {
+        LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
+        for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+            if (seq_pos[s0].empty()) {
                 continue;
             }
 
-            LLAMA_LOG_DEBUG("%s: %4d: [%4d, %4d]\n", __func__, s, seq_pos_min(s), seq_pos_max(s));
+            std::stringstream ss;
+            for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    ss << s1 << " ";
+                }
+            }
+
+            LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
+                    __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
         }
         LLAMA_LOG_DEBUG("%s: ]\n", __func__);
     }
@@ -485,7 +530,7 @@ void llama_batch_allocr::clear() {
     }
 
     for (auto & cur : seq_cpl) {
-        cur.clear();
+        std::fill(cur.begin(), cur.end(), false);
     }
 }
 
```
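The contiguity check works because `seq_pos` stores each sequence's positions in a `std::set<llama_pos>`: duplicates collapse, so a gap-free run must contain exactly `max - min + 1` distinct positions. A tiny self-contained illustration of that invariant (not the committed code):

```cpp
#include <cassert>
#include <cstdint>
#include <set>

int main() {
    std::set<int32_t> seq_pos = {5, 6, 7, 9}; // position 8 is missing

    const int32_t pos_min = *seq_pos.begin();
    const int32_t pos_max = *seq_pos.rbegin();

    // the committed check: a contiguous sequence covers exactly
    // (pos_max - pos_min + 1) distinct positions
    const bool has_gap = pos_max - pos_min + 1 > (int32_t) seq_pos.size();
    assert(has_gap); // 9 - 5 + 1 = 5 > 4 distinct positions -> gap detected

    return 0;
}
```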

src/llama-batch.h

Lines changed: 5 additions & 3 deletions

```diff
@@ -84,8 +84,10 @@ class llama_batch_allocr {
     llama_batch_allocr();
 
     // optionally fulfill the batch returned by llama_batch_get_one
-    // TODO: extend p0 to be per-sequence: provide `seq_pos_min` and `seq_pos_max` from the memory
-    bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
+    bool init(
+            const llama_batch & batch_inp,
+            const llama_vocab & vocab,
+            const llama_memory_i * memory);
 
     const llama_batch & get_batch() const;
 
@@ -109,7 +111,7 @@ class llama_batch_allocr {
     std::vector<int8_t> output;
 
     std::vector<std::set<llama_pos>> seq_pos; // the positions of each sequence
-    std::vector<std::vector<bool>>   seq_cpl; // if sequences i and j are coupled
+    std::vector<std::vector<bool>>   seq_cpl; // if sequence i is coupled to sequence j
 
     int debug;
 };
```
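Note the deliberately one-directional convention behind the updated comment: for a token tagged with several sequence ids, each secondary sequence is marked as coupled to the primary (first) one, i.e. `seq_cpl[s1][s0] = true`, not the symmetric pair. A small standalone sketch of the fill, where the 64 mirrors `LLAMA_MAX_PARALLEL_SEQUENCES`:

```cpp
#include <vector>

int main() {
    const int n_seq_max = 64; // mirrors LLAMA_MAX_PARALLEL_SEQUENCES

    std::vector<std::vector<bool>> seq_cpl(n_seq_max, std::vector<bool>(n_seq_max, false));

    // one token assigned to sequences {0, 2}: 0 is the primary (first) id
    const std::vector<int> token_seq_ids = {0, 2};

    for (size_t s = 1; s < token_seq_ids.size(); ++s) {
        const int s0 = token_seq_ids[0];
        const int s1 = token_seq_ids[s];
        seq_cpl[s1][s0] = true; // "sequence s1 is coupled to sequence s0"
    }

    // result: seq_cpl[2][0] == true, while seq_cpl[0][2] stays false
    return seq_cpl[2][0] && !seq_cpl[0][2] ? 0 : 1;
}
```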

src/llama-context.cpp

Lines changed: 2 additions & 2 deletions

```diff
@@ -729,7 +729,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -896,7 +896,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }
 
     // temporary allocate memory for the input batch if needed
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
```
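From the caller's perspective nothing changes in the common path: a batch from `llama_batch_get_one` still carries `pos == NULL`, and `init` now fills in positions from the memory state rather than from a scalar `p0`. A hedged usage sketch, assuming an already-initialized context and placeholder tokens:

```cpp
#include "llama.h"

#include <vector>

// assumes `ctx` was created elsewhere and `toks` holds the next tokens to feed
void step(llama_context * ctx, std::vector<llama_token> & toks) {
    // llama_batch_get_one() leaves batch.pos == NULL ...
    llama_batch batch = llama_batch_get_one(toks.data(), (int32_t) toks.size());

    // ... so positions now continue from memory->seq_pos_max(s) + 1 per sequence;
    // with this commit, non-contiguous or diverged multi-sequence inputs fail here
    if (llama_decode(ctx, batch) != 0) {
        // handle the error
    }
}
```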

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions

```diff
@@ -4,6 +4,7 @@
 
 #include <cstdint>
 
+// TODO: rename to something shorter
 #define LLAMA_MAX_PARALLEL_SEQUENCES 64
 
 struct llama_cparams {
```
