Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion src/llama-batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cassert>
#include <cstring>
#include <algorithm>
#include <sstream>

llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
// clear empty sequences
Expand Down Expand Up @@ -283,7 +284,10 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
);
}

llama_batch_allocr::llama_batch_allocr() = default;
llama_batch_allocr::llama_batch_allocr() {
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
}

bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
clear();
Expand Down Expand Up @@ -356,6 +360,53 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
n_outputs += batch.logits[i] != 0;
}

if (debug > 0) {
LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);

if (debug > 1) {
int seq_id_max = 0;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
}
}
}
++seq_id_max;

LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
for (int32_t i = 0; i < batch.n_tokens; ++i) {
std::vector<int8_t> seq_id(seq_id_max);

for (int s = 0; s < batch.n_seq_id[i]; ++s) {
seq_id[batch.seq_id[i][s]] = 1;
}

std::stringstream ss;
for (int s = 0; s < seq_id_max; ++s) {
if (seq_id[s]) {
ss << s%10;
} else {
ss << ".";
}
}

LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
__func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
}
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
}
}

return true;
}

Expand Down
2 changes: 2 additions & 0 deletions src/llama-batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,6 @@ class llama_batch_allocr {
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id *> seq_id;
std::vector<int8_t> output;

int debug;
};
Loading