 #include "llama-batch.h"
 
+#include "llama-impl.h"
+#include "llama-cparams.h"
+#include "llama-vocab.h"
+
 #include <cassert>
 #include <cstring>
 #include <algorithm>
+#include <sstream>
 
 llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
     // clear empty sequences
@@ -105,12 +110,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
             ubatch.seq_id = batch->seq_id + seq.offset;
         }
     }
-    if (logits_all) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.output[ubatch.n_tokens + i] = 1;
-            out_ids.push_back(ids[seq.offset + i]);
-        }
-    } else if (batch->logits) {
+    if (batch->logits) {
         if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
                 size_t id = ids[seq.offset + i];
@@ -197,11 +197,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
-    this->logits_all = logits_all;
 
     n_tokens = batch.n_tokens;
     ids.resize(n_tokens);
@@ -285,9 +284,45 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
     );
 }
 
-llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
-    batch = in_batch;
+llama_batch_allocr::llama_batch_allocr() {
+    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
+    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+}
+
+bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+    clear();
+
+    batch = batch_inp;
+
     GGML_ASSERT(batch.n_tokens > 0);
+
+    if (!batch.pos) {
+        if (batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return false;
+        }
+    }
+
+    if (batch.token) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return false;
+            }
+        }
+    }
+
+    if (batch.seq_id) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                    return false;
+                }
+            }
+        }
+    }
+
     if (!batch.pos) {
         assert(p0 >= 0);
         pos.resize(batch.n_tokens);
@@ -296,13 +331,15 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.pos = pos.data();
     }
+
     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
             n_seq_id[i] = seq_id_0.size();
         }
         batch.n_seq_id = n_seq_id.data();
     }
+
     if (!batch.seq_id) {
         seq_id.resize(batch.n_tokens + 1);
         seq_id[batch.n_tokens] = NULL;
@@ -311,11 +348,83 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         }
         batch.seq_id = seq_id.data();
     }
+
     if (!batch.logits) {
-        logits.resize(batch.n_tokens);
-        logits[logits.size() - 1] = true;
-        batch.logits = logits.data();
+        // by default return the output only for the last token
+        output.resize(batch.n_tokens);
+        output[output.size() - 1] = true;
+        batch.logits = output.data();
     }
+
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        n_outputs += batch.logits[i] != 0;
+    }
+
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
+        LLAMA_LOG_DEBUG("%s: n_tokens  = %d\n", __func__, batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: token     = %p\n", __func__, (void *) batch.token);
+        LLAMA_LOG_DEBUG("%s: embd      = %p\n", __func__, (void *) batch.embd);
+        LLAMA_LOG_DEBUG("%s: pos       = %p\n", __func__, (void *) batch.pos);
+        LLAMA_LOG_DEBUG("%s: n_seq_id  = %p\n", __func__, (void *) batch.n_seq_id);
+        LLAMA_LOG_DEBUG("%s: seq_id    = %p\n", __func__, (void *) batch.seq_id);
+        LLAMA_LOG_DEBUG("%s: logits    = %p\n", __func__, (void *) batch.logits);
+        LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
+
+        if (debug > 1) {
+            int seq_id_max = 0;
+            for (int32_t i = 0; i < batch.n_tokens; ++i) {
+                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                    seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
+                }
+            }
+            ++seq_id_max;
+
+            LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
+            for (int32_t i = 0; i < batch.n_tokens; ++i) {
+                std::vector<int8_t> seq_id(seq_id_max);
+
+                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                    seq_id[batch.seq_id[i][s]] = 1;
+                }
+
+                std::stringstream ss;
+                for (int s = 0; s < seq_id_max; ++s) {
+                    if (seq_id[s]) {
+                        ss << s%10;
+                    } else {
+                        ss << ".";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s:  %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                        __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
+                        batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
+    }
+
+    return true;
+}
+
+const llama_batch & llama_batch_allocr::get_batch() const {
+    return batch;
+}
+
+uint32_t llama_batch_allocr::get_n_outputs() const {
+    return n_outputs;
+}
+
+void llama_batch_allocr::clear() {
+    n_outputs = 0;
+
+    batch = {};
+    pos.clear();
+    n_seq_id.clear();
+    seq_id.clear();
+    output.clear();
 }
 
 //
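
For readers following the API change above: the allocator is now constructed once (reading the LLAMA_BATCH_DEBUG environment variable to pick a debug level) and then re-initialized per input batch via init(), which validates tokens, positions and sequence ids and returns false instead of aborting, filling in default pos/n_seq_id/seq_id/logits arrays along the way. Below is a minimal caller-side sketch of that flow, not part of this commit; the names batch_inp, vocab and p0 are assumptions standing in for whatever the surrounding context provides.

// sketch only: assumes an already-built llama_batch (batch_inp), a loaded
// llama_vocab (vocab) and a starting position p0 from the caller's context
llama_batch_allocr balloc;                    // reads LLAMA_BATCH_DEBUG once

if (!balloc.init(batch_inp, vocab, p0)) {     // validates the batch, fills in missing defaults
    return false;                             // invalid input was reported via LLAMA_LOG_ERROR
}

const llama_batch & batch     = balloc.get_batch();      // batch with defaults applied
const uint32_t      n_outputs = balloc.get_n_outputs();  // tokens flagged for output/logits

Because init() starts by calling clear(), the same allocator instance can be reused across successive batches.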