128 changes: 67 additions & 61 deletions src/llama.cpp
@@ -5190,6 +5190,57 @@ struct llama_model_loader {
}
};

// temporarily allocate memory for the input batch if needed
static const llama_seq_id batch_default_seq_id = 0;
struct llama_batch_allocr {
std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> logits;
struct llama_batch batch;
// optionally fill in missing fields of a batch returned by llama_batch_get_one
llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
batch = in_batch;
GGML_ASSERT(batch.n_tokens > 0);
if (!batch.pos) {
// determine the last position in KV cache
llama_pos last_pos = -1;
for (const auto & cell : ctx.kv_self.cells) {
if (cell.has_seq_id(batch_default_seq_id)) {
last_pos = std::max(last_pos, cell.pos);
}
}
last_pos++; // next position
pos.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i] = i + last_pos;
}
batch.pos = pos.data();
}
if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
n_seq_id[i] = seq_id_0.size();
}
batch.n_seq_id = n_seq_id.data();
}
if (!batch.seq_id) {
seq_id.resize(batch.n_tokens + 1);
seq_id[batch.n_tokens] = NULL;
for (int32_t i = 0; i < batch.n_tokens; i++) {
seq_id[i] = seq_id_0.data();
}
batch.seq_id = seq_id.data();
}
if (!batch.logits) {
logits.resize(batch.n_tokens);
logits[logits.size() - 1] = true;
batch.logits = logits.data();
}
}
};

template<>
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
uint32_t tmp;
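
For reference, here is a standalone sketch (not part of the diff) of the defaults the allocator produces. Values are illustrative: given a batch that carries only tokens, positions continue from the last sequence-0 position in the KV cache, every token is assigned to the default sequence, and logits are requested only for the final token.

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative sketch of the defaulting rules in llama_batch_allocr
// (simplified types, not the real llama.cpp structs). For a 4-token
// batch and an empty KV cache, last_pos == -1, so positions start at 0.
int main() {
    const int32_t n_tokens = 4;
    const int32_t last_pos = -1;                // max seq-0 position found in the cache

    std::vector<int32_t> pos(n_tokens);
    std::vector<int32_t> n_seq_id(n_tokens, 1); // each token belongs to one sequence
    std::vector<int8_t>  logits(n_tokens, 0);

    for (int32_t i = 0; i < n_tokens; i++) {
        pos[i] = i + last_pos + 1;              // consecutive, continuing the cache
    }
    logits[n_tokens - 1] = 1;                   // logits only for the last token

    for (int32_t i = 0; i < n_tokens; i++) {
        printf("token %d: pos=%d n_seq_id=%d logits=%d\n",
               (int) i, (int) pos[i], (int) n_seq_id[i], (int) logits[i]);
    }
    return 0;
}
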
@@ -17108,16 +17159,20 @@ static void llama_graph_compute(
//
static int llama_decode_internal(
llama_context & lctx,
llama_batch batch) {
llama_batch inp_batch) {

lctx.is_encoding = false;
const uint32_t n_tokens_all = batch.n_tokens;

if (n_tokens_all == 0) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}

// temporarily allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(lctx, inp_batch);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens_all = batch.n_tokens;

const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
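
A note on lifetime (reviewer commentary, not part of the diff): with the allocator now a local of llama_decode_internal, the vectors backing the filled-in fields live for the entire decode call, and batch is a reference into allocator-owned storage. The encode hunk below gets the identical treatment. A minimal sketch of the pattern, using hypothetical simplified types:

#include <cstdint>
#include <vector>

// Hypothetical simplified types showing the ownership pattern:
// the owner holds the backing vector, the view only points into it.
struct batch_view {
    const int32_t * pos;
    int32_t         n_tokens;
};

struct batch_owner {
    std::vector<int32_t> pos_storage;
    batch_view           view;

    explicit batch_owner(int32_t n) : pos_storage(n) {
        for (int32_t i = 0; i < n; i++) {
            pos_storage[i] = i;
        }
        view = { pos_storage.data(), n };
    }
};

static int decode_internal(const batch_view & v) {
    // v.pos remains valid here because the owner outlives this call
    return v.n_tokens > 0 ? 0 : -1;
}

int decode(int32_t n_tokens) {
    batch_owner owner(n_tokens); // storage allocated on this stack frame
    return decode_internal(owner.view);
}                                // storage freed when owner goes out of scope
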
@@ -17422,17 +17477,20 @@ static int llama_decode_internal(
//
static int llama_encode_internal(
llama_context & lctx,
llama_batch batch) {
llama_batch inp_batch) {

lctx.is_encoding = true;

const uint32_t n_tokens = batch.n_tokens;

if (n_tokens == 0) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}

// temporarily allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(lctx, inp_batch);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens = batch.n_tokens;

const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
@@ -21127,61 +21185,10 @@ void llama_batch_free(struct llama_batch batch) {
if (batch.logits) free(batch.logits);
}

// temporary allocate memory for the input batch if needed
static const llama_seq_id batch_default_seq_id = 0;
struct llama_batch_allocr {
std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> logits;
struct llama_batch batch;
// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
batch = in_batch;
if (!batch.pos) {
// determine the last position in KV cache
llama_pos last_pos = -1;
for (const auto & cell : ctx->kv_self.cells) {
if (cell.has_seq_id(batch_default_seq_id)) {
last_pos = std::max(last_pos, cell.pos);
}
}
last_pos++; // next position
pos.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i] = i+last_pos;
}
batch.pos = pos.data();
}
if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
n_seq_id[i] = seq_id_0.size();
}
batch.n_seq_id = n_seq_id.data();
}
if (!batch.seq_id) {
seq_id.resize(batch.n_tokens + 1);
seq_id[batch.n_tokens] = NULL;
for (int32_t i = 0; i < batch.n_tokens; i++) {
seq_id[i] = seq_id_0.data();
}
batch.seq_id = seq_id.data();
}
if (!batch.logits) {
logits.resize(batch.n_tokens);
logits[logits.size() - 1] = true;
batch.logits = logits.data();
}
}
};

int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch) {
llama_batch_allocr batch_allocr(ctx, batch);
const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
const int ret = llama_encode_internal(*ctx, batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
}
@@ -21192,8 +21199,7 @@ int32_t llama_encode(
int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
llama_batch_allocr batch_allocr(ctx, batch);
const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
const int ret = llama_decode_internal(*ctx, batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
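
End to end, the change means a bare batch can now be handed straight to llama_decode. A hypothetical caller sketch (ctx and the tokenized prompt are assumed to exist, and the two-argument llama_batch_get_one signature is assumed; not part of this diff):

// Hypothetical caller: a batch that only carries tokens is completed
// internally (positions, sequence ids, logits flags), so no manual
// field setup is needed before decoding.
std::vector<llama_token> tokens = /* ... tokenized prompt ... */;
llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
if (llama_decode(ctx, batch) != 0) {
    fprintf(stderr, "llama_decode failed\n");
}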