src/llama-batch.cpp (10 changes: 2 additions & 8 deletions)

```diff
@@ -105,12 +105,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
                 ubatch.seq_id = batch->seq_id + seq.offset;
             }
         }
-    if (logits_all) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.output[ubatch.n_tokens + i] = 1;
-            out_ids.push_back(ids[seq.offset + i]);
-        }
-    } else if (batch->logits) {
+    if (batch->logits) {
         if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
                 size_t id = ids[seq.offset + i];
@@ -197,11 +192,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
-    this->logits_all = logits_all;
 
     n_tokens = batch.n_tokens;
     ids.resize(n_tokens);
```
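With the `logits_all` branch removed from `add_seq_to_ubatch`, the per-token `batch->logits` array is the only signal for which tokens produce output. A minimal migration sketch, not part of this PR, assuming the caller has allocated the `logits` array (e.g. via `llama_batch_init`): a caller that previously passed `logits_all = true` now marks every token explicitly.

```cpp
// Hedged sketch: request an output row for every token by setting the
// per-token flags, instead of the removed logits_all constructor flag.
for (int32_t i = 0; i < batch.n_tokens; ++i) {
    batch.logits[i] = 1; // int8_t flag; non-zero = produce output for token i
}
llama_sbatch sbatch(batch, n_embd, /* simple_split */ true);
```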
src/llama-batch.h (4 changes: 1 addition & 3 deletions)

```diff
@@ -39,8 +39,6 @@ struct llama_sbatch {
 
     size_t n_embd;
 
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
     // sorted indices into the batch
     std::vector<int64_t> ids;
     // batch indices of the output
@@ -76,7 +74,7 @@ struct llama_sbatch {
     llama_ubatch split_seq(size_t n_ubatch);
 
     llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
 // temporary allocate memory for the input batch if needed
```
src/llama-context.cpp (6 changes: 3 additions & 3 deletions)

```diff
@@ -764,7 +764,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -976,7 +976,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate) {
             return -2;
         }
@@ -2080,7 +2080,7 @@ void llama_context::opt_epoch_iter(
 
     int64_t n_outputs_all = n_tokens_all;
 
-    auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+    auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
     if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
         break;
```
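`decode` no longer derives a `logits_all` flag from `n_outputs_all == n_tokens_all`; the same information now reaches the memory module through the batch contents alone. A hedged caller-side sketch against the public `llama.h` API (`ctx`, `n_prompt`, and `prompt_tokens` are assumed to exist):

```cpp
// Hedged usage sketch: output selection is expressed purely via
// llama_batch::logits. Mark all positions to get logits everywhere,
// or just the last one for ordinary sampling.
llama_batch batch = llama_batch_init(n_prompt, /* embd */ 0, /* n_seq_max */ 1);
batch.n_tokens = n_prompt;
for (int32_t i = 0; i < n_prompt; ++i) {
    batch.token[i]     = prompt_tokens[i];
    batch.pos[i]       = i;
    batch.n_seq_id[i]  = 1;
    batch.seq_id[i][0] = 0;
    batch.logits[i]    = (i == n_prompt - 1); // flip to 1 for all positions
}
if (llama_decode(ctx, batch) != 0) {
    // non-zero: decode failed (e.g. -2 when the memory state could not be initialized)
}
llama_batch_free(batch);
```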
src/llama-kv-cache-recurrent.cpp (4 changes: 2 additions & 2 deletions)

```diff
@@ -359,10 +359,10 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
 
```
src/llama-kv-cache-recurrent.h (3 changes: 1 addition & 2 deletions)

```diff
@@ -32,8 +32,7 @@ class llama_kv_cache_recurrent : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
```
src/llama-kv-cache-unified-iswa.cpp (6 changes: 3 additions & 3 deletions)

```diff
@@ -95,12 +95,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     // first try simple split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
 
@@ -128,7 +128,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     // if it fails, try equal split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
         std::vector<llama_ubatch> ubatches;
 
```
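The two-phase strategy in the iSWA cache is unchanged by this PR: try a simple split first and fall back to an equal split, with each attempt wrapped in a `do { ... } while (false)` block so a `break` abandons it. A simplified sketch of the idiom, where `can_alloc` and `make_state` are hypothetical stand-ins for the real slot-finding and state-construction code:

```cpp
do {
    auto sbatch = llama_sbatch(batch, hparams.n_embd, /* simple_split */ true);

    std::vector<llama_ubatch> ubatches;
    while (sbatch.n_tokens > 0) {
        ubatches.push_back(sbatch.split_simple(n_ubatch));
    }

    if (!can_alloc(ubatches)) {
        break; // simple split does not fit; fall through to the equal split
    }
    return make_state(std::move(sbatch), std::move(ubatches));
} while (false);

do {
    auto sbatch = llama_sbatch(batch, hparams.n_embd, /* simple_split */ false);
    // ... same pattern, splitting with sbatch.split_equal(n_ubatch) ...
} while (false);
```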
src/llama-kv-cache-unified-iswa.h (3 changes: 1 addition & 2 deletions)

```diff
@@ -34,8 +34,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
```
src/llama-kv-cache-unified.cpp (5 changes: 2 additions & 3 deletions)

```diff
@@ -310,12 +310,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) {
+            bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
         while (sbatch.n_tokens > 0) {
```
src/llama-kv-cache-unified.h (3 changes: 1 addition & 2 deletions)

```diff
@@ -59,8 +59,7 @@ class llama_kv_cache_unified : public llama_memory_i {
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
```
src/llama-memory.h (3 changes: 1 addition & 2 deletions)

```diff
@@ -73,8 +73,7 @@ struct llama_memory_i {
     virtual llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) = 0;
+            bool embd_pooled) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;
```
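After this change, every concrete memory module overrides the three-parameter pure virtual. A hedged sketch of a hypothetical implementation (only `init_batch` is shown; the remaining `llama_memory_i` overrides and the actual state construction are elided):

```cpp
class my_memory : public llama_memory_i {
public:
    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled) override {
        GGML_UNUSED(embd_pooled);

        auto sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);

        std::vector<llama_ubatch> ubatches;
        while (sbatch.n_tokens > 0) {
            ubatches.push_back(sbatch.split_simple(n_ubatch));
        }

        // ... reserve cells for the ubatches and wrap them in a
        //     llama_memory_state_ptr (implementation-specific) ...
        return nullptr; // placeholder
    }

    // ... remaining llama_memory_i overrides elided ...

private:
    size_t n_embd = 0;
};
```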