| 
7 | 7 | #include <cassert>  | 
8 | 8 | #include <cstring>  | 
9 | 9 | #include <algorithm>  | 
 | 10 | +#include <sstream>  | 
10 | 11 | 
 
  | 
11 | 12 | llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {  | 
12 | 13 |     // clear empty sequences  | 
@@ -283,7 +284,10 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple  | 
283 | 284 |             );  | 
284 | 285 | }  | 
285 | 286 | 
 
  | 
286 |  | -llama_batch_allocr::llama_batch_allocr() = default;  | 
 | 287 | +llama_batch_allocr::llama_batch_allocr() {  | 
 | 288 | +    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");  | 
 | 289 | +    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;  | 
 | 290 | +}  | 
287 | 291 | 
 
  | 
288 | 292 | bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {  | 
289 | 293 |     clear();  | 
@@ -356,6 +360,53 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &  | 
356 | 360 |         n_outputs += batch.logits[i] != 0;  | 
357 | 361 |     }  | 
358 | 362 | 
 
  | 
 | 363 | +    if (debug > 0) {  | 
 | 364 | +        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);  | 
 | 365 | +        LLAMA_LOG_DEBUG("%s:   n_tokens  = %d\n", __func__, batch.n_tokens);  | 
 | 366 | +        LLAMA_LOG_DEBUG("%s:   token     = %p\n", __func__, (void *) batch.token);  | 
 | 367 | +        LLAMA_LOG_DEBUG("%s:   embd      = %p\n", __func__, (void *) batch.embd);  | 
 | 368 | +        LLAMA_LOG_DEBUG("%s:   pos       = %p\n", __func__, (void *) batch.pos);  | 
 | 369 | +        LLAMA_LOG_DEBUG("%s:   n_seq_id  = %p\n", __func__, (void *) batch.n_seq_id);  | 
 | 370 | +        LLAMA_LOG_DEBUG("%s:   seq_id    = %p\n", __func__, (void *) batch.seq_id);  | 
 | 371 | +        LLAMA_LOG_DEBUG("%s:   logits    = %p\n", __func__, (void *) batch.logits);  | 
 | 372 | +        LLAMA_LOG_DEBUG("%s:   n_outputs = %d\n", __func__, n_outputs);  | 
 | 373 | + | 
 | 374 | +        if (debug > 1) {  | 
 | 375 | +            int seq_id_max = 0;  | 
 | 376 | +            for (int32_t i = 0; i < batch.n_tokens; ++i) {  | 
 | 377 | +                for (int s = 0; s < batch.n_seq_id[i]; ++s) {  | 
 | 378 | +                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {  | 
 | 379 | +                        seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);  | 
 | 380 | +                    }  | 
 | 381 | +                }  | 
 | 382 | +            }  | 
 | 383 | +            ++seq_id_max;  | 
 | 384 | + | 
 | 385 | +            LLAMA_LOG_DEBUG("%s:   token     = [\n", __func__);  | 
 | 386 | +            for (int32_t i = 0; i < batch.n_tokens; ++i) {  | 
 | 387 | +                std::vector<int8_t> seq_id(seq_id_max);  | 
 | 388 | + | 
 | 389 | +                for (int s = 0; s < batch.n_seq_id[i]; ++s) {  | 
 | 390 | +                    seq_id[batch.seq_id[i][s]] = 1;  | 
 | 391 | +                }  | 
 | 392 | + | 
 | 393 | +                std::stringstream ss;  | 
 | 394 | +                for (int s = 0; s < seq_id_max; ++s) {  | 
 | 395 | +                    if (seq_id[s]) {  | 
 | 396 | +                        ss << s%10;  | 
 | 397 | +                    } else {  | 
 | 398 | +                        ss << ".";  | 
 | 399 | +                    }  | 
 | 400 | +                }  | 
 | 401 | + | 
 | 402 | +                LLAMA_LOG_DEBUG("%s:  %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",  | 
 | 403 | +                        __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),  | 
 | 404 | +                        batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);  | 
 | 405 | +            }  | 
 | 406 | +            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);  | 
 | 407 | +        }  | 
 | 408 | +    }  | 
 | 409 | + | 
359 | 410 |     return true;  | 
360 | 411 | }  | 
361 | 412 | 
 
  | 
 | 
0 commit comments