7 | 7 | #include <cassert>
8 | 8 | #include <cstring>
9 | 9 | #include <algorithm>
| 10 | +#include <sstream>
10 | 11 |
11 | 12 | llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
12 | 13 |     // clear empty sequences
@@ -283,7 +284,10 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
283 | 284 |     );
284 | 285 | }
285 | 286 |
286 | | -llama_batch_allocr::llama_batch_allocr() = default;
| 287 | +llama_batch_allocr::llama_batch_allocr() {
| 288 | +    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
| 289 | +    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
| 290 | +}
287 | 291 |
288 | 292 | bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
289 | 293 |     clear();
@@ -356,6 +360,51 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
356 | 360 |         n_outputs += batch.logits[i] != 0;
357 | 361 |     }
358 | 362 |
| 363 | +    if (debug > 0) {
| 364 | +        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
| 365 | +        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
| 366 | +        LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
| 367 | +        LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
| 368 | +        LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
| 369 | +        LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
| 370 | +        LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
| 371 | +        LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
| 372 | +        LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
| 373 | +
| 374 | +        if (debug > 1) {
| 375 | +            int seq_id_max = 0;
| 376 | +            for (int32_t i = 0; i < batch.n_tokens; ++i) {
| 377 | +                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
| 378 | +                    seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
| 379 | +                }
| 380 | +            }
| 381 | +            ++seq_id_max;
| 382 | +
| 383 | +            LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
| 384 | +            for (int32_t i = 0; i < batch.n_tokens; ++i) {
| 385 | +                std::vector<int8_t> seq_id(seq_id_max);
| 386 | +
| 387 | +                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
| 388 | +                    seq_id[batch.seq_id[i][s]] = 1;
| 389 | +                }
| 390 | +
| 391 | +                std::stringstream ss;
| 392 | +                for (int s = 0; s < seq_id_max; ++s) {
| 393 | +                    if (seq_id[s]) {
| 394 | +                        ss << s%10;
| 395 | +                    } else {
| 396 | +                        ss << ".";
| 397 | +                    }
| 398 | +                }
| 399 | +
| 400 | +                LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
| 401 | +                    __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
| 402 | +                    batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
| 403 | +            }
| 404 | +            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
| 405 | +        }
| 406 | +    }
| 407 | +
359 | 408 |     return true;
360 | 409 | }
361 | 410 |
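Note: the new logging is gated by the LLAMA_BATCH_DEBUG environment variable read in the constructor above; level 1 prints the batch pointers and counts, level 2 additionally dumps every token with its position and a per-sequence membership map. A minimal way to exercise it (a sketch assuming a standard llama.cpp build with debug-level logging visible; the binary name, model path, and prompt are illustrative):

    LLAMA_BATCH_DEBUG=2 ./llama-cli -m model.gguf -p "Hello" -n 16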