@@ -622,8 +622,16 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
622622 );
623623}
624624
625- llama_batch_allocr::llama_batch_allocr (struct llama_batch in_batch, llama_pos p0) {
626- batch = in_batch;
625+ llama_batch_allocr::llama_batch_allocr () {
626+ const char * LLAMA_BATCH_DEBUG = getenv (" LLAMA_BATCH_DEBUG" );
627+ debug = LLAMA_BATCH_DEBUG ? atoi (LLAMA_BATCH_DEBUG) : 0 ;
628+ }
629+
630+ bool llama_batch_allocr::init (const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
631+ clear ();
632+
633+ batch = batch_inp;
634+
627635 GGML_ASSERT (batch.n_tokens > 0 );
628636
629637 if (!batch.pos ) {
@@ -690,6 +698,53 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
690698 n_outputs += batch.logits [i] != 0 ;
691699 }
692700
701+ if (debug > 0 ) {
702+ LLAMA_LOG_DEBUG (" %s: input batch info (p0 = %d):\n " , __func__, p0);
703+ LLAMA_LOG_DEBUG (" %s: n_tokens = %d\n " , __func__, batch.n_tokens );
704+ LLAMA_LOG_DEBUG (" %s: token = %p\n " , __func__, (void *) batch.token );
705+ LLAMA_LOG_DEBUG (" %s: embd = %p\n " , __func__, (void *) batch.embd );
706+ LLAMA_LOG_DEBUG (" %s: pos = %p\n " , __func__, (void *) batch.pos );
707+ LLAMA_LOG_DEBUG (" %s: n_seq_id = %p\n " , __func__, (void *) batch.n_seq_id );
708+ LLAMA_LOG_DEBUG (" %s: seq_id = %p\n " , __func__, (void *) batch.seq_id );
709+ LLAMA_LOG_DEBUG (" %s: logits = %p\n " , __func__, (void *) batch.logits );
710+ LLAMA_LOG_DEBUG (" %s: n_outputs = %d\n " , __func__, n_outputs);
711+
712+ if (debug > 1 ) {
713+ int seq_id_max = 0 ;
714+ for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
715+ for (int s = 0 ; s < batch.n_seq_id [i]; ++s) {
716+ for (int s = 0 ; s < batch.n_seq_id [i]; ++s) {
717+ seq_id_max = std::max (seq_id_max, batch.seq_id [i][s]);
718+ }
719+ }
720+ }
721+ ++seq_id_max;
722+
723+ LLAMA_LOG_DEBUG (" %s: token = [\n " , __func__);
724+ for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
725+ std::vector<int8_t > seq_id (seq_id_max);
726+
727+ for (int s = 0 ; s < batch.n_seq_id [i]; ++s) {
728+ seq_id[batch.seq_id [i][s]] = 1 ;
729+ }
730+
731+ std::stringstream ss;
732+ for (int s = 0 ; s < seq_id_max; ++s) {
733+ if (seq_id[s]) {
734+ ss << s%10 ;
735+ } else {
736+ ss << " ." ;
737+ }
738+ }
739+
740+ LLAMA_LOG_DEBUG (" %s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n " ,
741+ __func__, i, batch.token [i], vocab.token_to_piece (batch.token [i]).c_str (),
742+ batch.pos [i], batch.n_seq_id [i], ss.str ().c_str (), batch.logits [i]);
743+ }
744+ LLAMA_LOG_DEBUG (" %s: ]\n " , __func__);
745+ }
746+ }
747+
693748 return true ;
694749}
695750
0 commit comments