Commit 4460d36

ggerganov authored and Minh141120 committed
batch : auto-gen positions + verify multi-sequence input (ggml-org#14177)
* batch : verify multi-sequence input batches ggml-ci
* cont : auto-gen positions + verify multi-seq input ggml-ci
* cont : first print debug info, then perform validation ggml-ci
* cont : fix position auto-gen + add comments ggml-ci
1 parent 649cd66 commit 4460d36

File tree

4 files changed: +153, −27 lines changed

src/llama-batch.cpp

Lines changed: 136 additions & 17 deletions
@@ -3,6 +3,7 @@
 #include "llama-impl.h"
 #include "llama-cparams.h"
 #include "llama-vocab.h"
+#include "llama-memory.h"

 #include <cassert>
 #include <cstring>
@@ -625,21 +626,27 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
 llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+
+    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    for (auto & cur : seq_cpl) {
+        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    }
 }

-bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory) {
     clear();

     batch = batch_inp;

     GGML_ASSERT(batch.n_tokens > 0);

-    if (!batch.pos) {
-        if (batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return false;
-        }
-    }
+    //
+    // validate input batch
+    //

     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -661,14 +668,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         }
     }

-    if (!batch.pos) {
-        assert(p0 >= 0);
-        pos.resize(batch.n_tokens);
-        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = p0 + i;
-        }
-        batch.pos = pos.data();
-    }
+    //
+    // auto-generate missing fields
+    //

     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
@@ -687,20 +689,69 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.seq_id = seq_id.data();
     }

+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (!memory) {
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
+            }
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            pos[i] = p0[seq_id];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                p0[batch.seq_id[i][s]] = pos[i] + 1;
+            }
+        }
+
+        batch.pos = pos.data();
+    }
+
     if (!batch.logits) {
         // by default return the output only for the last token
         output.resize(batch.n_tokens);
         output[output.size() - 1] = true;
         batch.logits = output.data();
     }

+    //
+    // compute stats
+    //
+
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }

+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
+
+            if (s > 0) {
+                const llama_seq_id s0 = batch.seq_id[i][0];
+                const llama_seq_id s1 = batch.seq_id[i][s];
+
+                // mark that sequence s1 is coupled to s0
+                seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                //seq_cpl[s0][s1] = true;
+            }
+        }
+    }
+
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
         LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
         LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
         LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
@@ -742,6 +793,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
                         batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
             }
             LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+
+            LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
+            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+                if (seq_pos[s0].empty()) {
+                    continue;
+                }
+
+                std::stringstream ss;
+                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                    if (seq_cpl[s0][s1]) {
+                        ss << s1 << " ";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
+                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
+    }
+
+    //
+    // consistency checks
+    //
+
+    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but has diverged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
         }
     }

@@ -756,6 +859,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }

+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
+}
+
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
+}
+
 void llama_batch_allocr::clear() {
     n_outputs = 0;

@@ -764,6 +875,14 @@ void llama_batch_allocr::clear() {
     n_seq_id.clear();
     seq_id.clear();
     output.clear();
+
+    for (auto & cur : seq_pos) {
+        cur.clear();
+    }
+
+    for (auto & cur : seq_cpl) {
+        std::fill(cur.begin(), cur.end(), false);
+    }
 }

 //
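To see what the new position auto-generation does, here is a minimal standalone sketch of the same per-sequence bookkeeping: when `batch.pos` is missing, each token takes the next position of its first sequence, and every sequence the token belongs to advances past it. The `gen_positions` helper and its types are simplified stand-ins for illustration only, not the llama.cpp API:

```cpp
#include <cstdint>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// simplified stand-in for the auto-gen loop in llama_batch_allocr::init():
// seq_ids[i] lists the sequences token i belongs to; p0[s] is the next free
// position of sequence s (0 for an empty memory, seq_pos_max(s) + 1 otherwise)
static std::vector<llama_pos> gen_positions(
        const std::vector<std::vector<llama_seq_id>> & seq_ids,
        std::vector<llama_pos> p0) {
    std::vector<llama_pos> pos(seq_ids.size());

    for (size_t i = 0; i < seq_ids.size(); ++i) {
        pos[i] = p0[seq_ids[i][0]]; // the token's first sequence dictates its position

        for (llama_seq_id s : seq_ids[i]) {
            p0[s] = pos[i] + 1; // every sequence containing the token advances past it
        }
    }

    return pos;
}

int main() {
    // two tokens shared by sequences 0 and 1, followed by one token only in sequence 1
    const std::vector<std::vector<llama_seq_id>> seq_ids = { {0, 1}, {0, 1}, {1} };

    const auto pos = gen_positions(seq_ids, std::vector<llama_pos>(64, 0));

    return pos == std::vector<llama_pos>({0, 1, 2}) ? 0 : 1; // shared prefix stays in lockstep
}
```

A token shared by several sequences advances all of them at once, which is what keeps coupled sequences aligned for the consistency checks added at the end of `init()`.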

src/llama-batch.h

Lines changed: 13 additions & 5 deletions
@@ -7,8 +7,6 @@
 #include <array>
 #include <vector>
 #include <set>
-#include <bitset>
-#include <unordered_map>

 // keep this struct lightweight
 // it points to data in `llama_batch_allocr`
@@ -84,18 +82,25 @@ struct llama_sbatch {
     llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };

-// temporary allocate memory for the input batch if needed
+// a helper for sanitizing and fulfilling a batch
 class llama_batch_allocr {
 public:
     llama_batch_allocr();

-    // optionally fulfill the batch returned by llama_batch_get_one
-    bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
+    // sanitize and auto-gen missing data in the input batch
+    // memory is optional. if provided, it will be used to check for sequence continuity and to determine the positions
+    bool init(
+            const llama_batch & batch_inp,
+            const llama_vocab & vocab,
+            const llama_memory_i * memory);

     const llama_batch & get_batch() const;

     uint32_t get_n_outputs() const;

+    llama_pos seq_pos_min(llama_seq_id seq_id) const;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const;
+
 private:
     void clear();

@@ -110,5 +115,8 @@ class llama_batch_allocr {
     std::vector<llama_seq_id *> seq_id;
     std::vector<int8_t> output;

+    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<std::vector<bool>> seq_cpl;   // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
     int debug;
 };
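Storing each sequence's positions in an ordered `std::set<llama_pos>` is what makes the new accessors and the continuity check cheap: min and max are the set's endpoints, and since positions are unique, a gap exists exactly when `max - min + 1` exceeds the element count. A sketch with free functions standing in for the class methods (illustrative only, not the actual API):

```cpp
#include <cstdint>
#include <set>

using llama_pos = int32_t;

// mirrors seq_pos_min()/seq_pos_max(): -1 signals "sequence absent from the batch"
static llama_pos pos_min(const std::set<llama_pos> & sp) {
    return sp.empty() ? -1 : *sp.begin();
}
static llama_pos pos_max(const std::set<llama_pos> & sp) {
    return sp.empty() ? -1 : *sp.rbegin();
}

int main() {
    const std::set<llama_pos> contiguous = {5, 6, 7};
    const std::set<llama_pos> gapped     = {5, 7};

    // the diff's continuity check: positions are unique, so the span
    // max - min + 1 can only exceed size() when a position is missing
    const bool ok_a = pos_max(contiguous) - pos_min(contiguous) + 1 <= (int) contiguous.size(); // true
    const bool ok_b = pos_max(gapped)     - pos_min(gapped)     + 1 <= (int) gapped.size();     // false

    return (ok_a && !ok_b) ? 0 : 1;
}
```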

src/llama-context.cpp

Lines changed: 2 additions & 4 deletions
@@ -727,9 +727,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
         return -1;
     }

-    // temporary allocate memory for the input batch if needed
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -900,8 +899,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }

-    // temporary allocate memory for the input batch if needed
-    if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
+    if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
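With both call sites funneled through the same `init()`, the `memory` argument alone decides where auto-generated positions start: `nullptr` during encode means every sequence starts at 0, while `memory.get()` during decode resumes each sequence at `seq_pos_max(s) + 1`. A tiny sketch of that rule (`start_pos` is a hypothetical helper, named here only to illustrate the branch):

```cpp
#include <cstdint>

using llama_pos = int32_t;

// hypothetical helper: what p0[s] resolves to at each call site
// encode passes no memory  -> every sequence starts at position 0
// decode passes the memory -> sequence s resumes at seq_pos_max(s) + 1
static llama_pos start_pos(bool has_memory, llama_pos mem_seq_pos_max) {
    return has_memory ? mem_seq_pos_max + 1 : 0;
}

int main() {
    const llama_pos enc = start_pos(false, 9); // encode: 0, memory ignored
    const llama_pos dec = start_pos(true,  9); // decode: resumes at 10
    return (enc == 0 && dec == 10) ? 0 : 1;
}
```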

src/llama-cparams.h

Lines changed: 2 additions & 1 deletion
@@ -4,7 +4,8 @@

 #include <cstdint>

-#define LLAMA_MAX_SEQ 64
+// TODO: rename to something shorter
+#define LLAMA_MAX_PARALLEL_SEQUENCES 64

 struct llama_cparams {
     uint32_t n_ctx; // context size used during inference
