4 changes: 2 additions & 2 deletions include/llama.h
@@ -243,14 +243,14 @@ extern "C" {

typedef bool (*llama_progress_callback)(float progress, void * user_data);

// Input data for llama_decode
// Input data for llama_encode/llama_decode
// A llama_batch object can contain input about one or many sequences
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
//
// - token : the token ids of the input (used when embd is NULL)
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// (if set to NULL, the token position will be tracked automatically by llama_decode)
// (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
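The updated header comment documents that `pos`, `seq_id`, and `logits` may be left NULL and are then filled in automatically. A minimal caller-side sketch of relying on that behavior (not part of the diff; it assumes an existing `llama_context * ctx` and an already tokenized prompt):

```cpp
#include "llama.h"

#include <vector>

// Sketch: decode a single-sequence prompt and let llama_decode fill in the
// optional fields. pos, n_seq_id, seq_id and logits are all NULL here, so
// positions are tracked automatically, sequence id 0 is assumed, and only
// the last token produces output.
static int decode_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
    llama_batch batch = {};
    batch.n_tokens = (int32_t) tokens.size();
    batch.token    = tokens.data();
    // batch.embd, batch.pos, batch.n_seq_id, batch.seq_id, batch.logits stay NULL

    return llama_decode(ctx, batch); // 0 on success, non-zero on failure/warning (see llama.h)
}
```

For this single-sequence case, `llama_batch_get_one()` produces an equivalent batch.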
153 changes: 136 additions & 17 deletions src/llama-batch.cpp
@@ -3,6 +3,7 @@
#include "llama-impl.h"
#include "llama-cparams.h"
#include "llama-vocab.h"
#include "llama-memory.h"

#include <cassert>
#include <cstring>
@@ -287,21 +288,27 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
llama_batch_allocr::llama_batch_allocr() {
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;

seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
for (auto & cur : seq_cpl) {
cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
}
}

bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
bool llama_batch_allocr::init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory) {
clear();

batch = batch_inp;

GGML_ASSERT(batch.n_tokens > 0);

if (!batch.pos) {
if (batch.seq_id) {
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
return false;
}
}
//
// validate input batch
//

if (batch.token) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -323,14 +330,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
}
}

if (!batch.pos) {
assert(p0 >= 0);
pos.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i] = p0 + i;
}
batch.pos = pos.data();
}
//
// auto-generate missing fields
//

if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
@@ -349,20 +351,69 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
batch.seq_id = seq_id.data();
}

if (!batch.pos) {
pos.resize(batch.n_tokens);

// initialize the starting position for each sequence based on the positions in the memory
llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
if (!memory) {
p0[s] = 0;
} else {
p0[s] = memory->seq_pos_max(s) + 1;
}
}

for (int32_t i = 0; i < batch.n_tokens; i++) {
const llama_seq_id seq_id = batch.seq_id[i][0];

pos[i] = p0[seq_id];

for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
p0[batch.seq_id[i][s]] = pos[i] + 1;
}
}

batch.pos = pos.data();
}

if (!batch.logits) {
// by default return the output only for the last token
output.resize(batch.n_tokens);
output[output.size() - 1] = true;
batch.logits = output.data();
}

//
// compute stats
//

for (int32_t i = 0; i < batch.n_tokens; ++i) {
n_outputs += batch.logits[i] != 0;
}

// determine coupled sequences
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);

if (s > 0) {
const llama_seq_id s0 = batch.seq_id[i][0];
const llama_seq_id s1 = batch.seq_id[i][s];

// mark that sequence s1 is coupled to s0
seq_cpl[s1][s0] = true;

// note: the other way around is not necessary for now
//seq_cpl[s0][s1] = true;
}
}
}

if (debug > 0) {
LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
@@ -404,6 +455,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
}
LLAMA_LOG_DEBUG("%s: ]\n", __func__);

LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
if (seq_pos[s0].empty()) {
continue;
}

std::stringstream ss;
for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
if (seq_cpl[s0][s1]) {
ss << s1 << " ";
}
}

LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
__func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
}
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
}
}

//
// consistency checks
//

for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
if (seq_pos[s].empty()) {
continue;
}

if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
return false;
}

if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
return false;
}
}

if (memory) {
for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
if (seq_cpl[s0][s1]) {
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
return false;
}
}
}
}
}

@@ -418,6 +521,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs;
}

llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
}

llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
}

void llama_batch_allocr::clear() {
n_outputs = 0;

@@ -426,6 +537,14 @@ void llama_batch_allocr::clear() {
n_seq_id.clear();
seq_id.clear();
output.clear();

for (auto & cur : seq_pos) {
cur.clear();
}

for (auto & cur : seq_cpl) {
std::fill(cur.begin(), cur.end(), false);
}
}

//
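The new auto-fill walks the batch once, continuing each sequence from the last position stored in memory, and the consistency checks then reject batches whose positions do not line up. A standalone sketch of that walk (not part of the diff), using a plain `std::map` in place of the real `llama_memory_i::seq_pos_max()`; all names here are local to the example:

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Toy version of the position auto-fill in llama_batch_allocr::init():
// every token gets the next position of its first sequence, and each
// sequence the token belongs to then advances to that position + 1.
static std::vector<int32_t> fill_positions(
        const std::vector<std::vector<int32_t>> & token_seqs,  // seq ids per token
        const std::map<int32_t, int32_t> &        seq_pos_max) // last pos per seq, -1 if empty
{
    std::map<int32_t, int32_t> p0; // next position to assign, per sequence
    for (const auto & [s, pmax] : seq_pos_max) {
        p0[s] = pmax + 1;
    }

    std::vector<int32_t> pos(token_seqs.size());
    for (size_t i = 0; i < token_seqs.size(); ++i) {
        pos[i] = p0[token_seqs[i][0]]; // sequences not in the map start at position 0
        for (const int32_t s : token_seqs[i]) {
            p0[s] = pos[i] + 1;
        }
    }
    return pos;
}

// e.g. with seq_pos_max = {{0, 4}} (sequence 0 already holds positions 0..4),
// a batch of 3 tokens on sequence 0 gets positions {5, 6, 7}, which also
// satisfies the "starts right after the memory" consistency check above.
```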
17 changes: 14 additions & 3 deletions src/llama-batch.h
@@ -4,6 +4,7 @@

#include <array>
#include <vector>
#include <set>

// very similar to llama_batch,
// but has more metadata about sequences
@@ -77,18 +78,25 @@ struct llama_sbatch {
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
};

// temporary allocate memory for the input batch if needed
// a helper for sanitizing and fulfilling a batch
class llama_batch_allocr {
public:
llama_batch_allocr();

// optionally fulfill the batch returned by llama_batch_get_one
bool init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0);
// sanitize and auto-gen missing data in the input batch
// memory is optional. if provided, it will be used to check for sequence continuity and to determine the positions
bool init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory);

const llama_batch & get_batch() const;

uint32_t get_n_outputs() const;

llama_pos seq_pos_min(llama_seq_id seq_id) const;
llama_pos seq_pos_max(llama_seq_id seq_id) const;

private:
void clear();

@@ -103,5 +111,8 @@ class llama_batch_allocr {
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> output;

std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<std::vector<bool>> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

int debug;
};
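`seq_pos` and `seq_cpl` record, per sequence, which positions appear in the batch and which sequences share a token. A small sketch of how those are filled, mirroring the loop in `llama_batch_allocr::init()` above (container types simplified, names local to the example):

```cpp
#include <cstdint>
#include <set>
#include <vector>

// For every token, the first listed sequence acts as the "owner"; any other
// sequence on the same token is marked as coupled to it (seq_cpl[s1][s0]).
static void collect_seq_stats(
        const std::vector<std::vector<int32_t>> & token_seqs, // seq ids per token
        const std::vector<int32_t> &              token_pos,  // position per token
        std::vector<std::set<int32_t>> &          seq_pos,    // out: positions used per seq
        std::vector<std::vector<bool>> &          seq_cpl)    // out: coupling flags
{
    for (size_t i = 0; i < token_seqs.size(); ++i) {
        for (size_t s = 0; s < token_seqs[i].size(); ++s) {
            seq_pos[token_seqs[i][s]].insert(token_pos[i]);

            if (s > 0) {
                // token i belongs to both token_seqs[i][0] and token_seqs[i][s]
                seq_cpl[token_seqs[i][s]][token_seqs[i][0]] = true;
            }
        }
    }
}
```

The consistency checks in llama-batch.cpp then use these stats to require that coupled sequences cover identical position ranges in memory.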
6 changes: 2 additions & 4 deletions src/llama-context.cpp
@@ -727,9 +727,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
return -1;
}

// temporary allocate memory for the input batch if needed
// note: during encode, we always pass the full sequence starting from pos = 0
if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@@ -895,8 +894,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
return -1;
}

// temporary allocate memory for the input batch if needed
if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
1 change: 1 addition & 0 deletions src/llama-cparams.h
@@ -4,6 +4,7 @@

#include <cstdint>

// TODO: rename to something shorter
#define LLAMA_MAX_PARALLEL_SEQUENCES 64

struct llama_cparams {