#include "llama-batch.h"

+#include "llama-impl.h"
+#include "llama-cparams.h"
+#include "llama-vocab.h"
+
#include <cassert>
#include <cstring>
#include <algorithm>
@@ -279,9 +283,42 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
    );
}

-llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
-    batch = in_batch;
+llama_batch_allocr::llama_batch_allocr() = default;
+
+bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+    clear();
+
+    batch = batch_inp;
+
    GGML_ASSERT(batch.n_tokens > 0);
+
+    if (!batch.pos) {
+        if (batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return false;
+        }
+    }
+
+    if (batch.token) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return false;
+            }
+        }
+    }
+
+    if (batch.seq_id) {
+        for (int32_t i = 0; i < batch.n_tokens; ++i) {
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                if (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
+                    return false;
+                }
+            }
+        }
+    }
+
    if (!batch.pos) {
        assert(p0 >= 0);
        pos.resize(batch.n_tokens);
@@ -290,13 +327,15 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
        }
        batch.pos = pos.data();
    }
+
    if (!batch.n_seq_id) {
        n_seq_id.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            n_seq_id[i] = seq_id_0.size();
        }
        batch.n_seq_id = n_seq_id.data();
    }
+
    if (!batch.seq_id) {
        seq_id.resize(batch.n_tokens + 1);
        seq_id[batch.n_tokens] = NULL;
@@ -305,12 +344,37 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
        }
        batch.seq_id = seq_id.data();
    }
+
    if (!batch.logits) {
        // by default return the output only for the last token
        output.resize(batch.n_tokens);
        output[output.size() - 1] = true;
        batch.logits = output.data();
    }
+
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        n_outputs += batch.logits[i] != 0;
+    }
+
+    return true;
+}
+
+const llama_batch & llama_batch_allocr::get_batch() const {
+    return batch;
+}
+
+uint32_t llama_batch_allocr::get_n_outputs() const {
+    return n_outputs;
+}
+
+void llama_batch_allocr::clear() {
+    n_outputs = 0;
+
+    batch = {};
+    pos.clear();
+    n_seq_id.clear();
+    seq_id.clear();
+    output.clear();
}

//
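
A brief usage sketch of the reworked allocator API, for orientation only: the surrounding function, its name, and the error handling are assumptions, not part of this commit; only llama_batch_allocr::init(), get_batch(), get_n_outputs(), and clear() come from the diff above. The visible effect of the change is that malformed input (out-of-range tokens, seq_id >= LLAMA_MAX_PARALLEL_SEQUENCES, a seq_id array without positions) is now reported through a bool return from init() instead of being accepted by the constructor, and the allocator can be reused across calls because init() starts with clear().

#include "llama-batch.h"   // llama_batch_allocr (as modified in this diff)
#include "llama-vocab.h"   // llama_vocab

// Hypothetical call-site sketch: prepare and validate a batch before decoding.
static int32_t prepare_batch(llama_batch_allocr & balloc,
                             const llama_batch  & inp,
                             const llama_vocab  & vocab,
                             llama_pos            p0) {
    // init() clears any previous state, validates the input, and fills in the
    // optional fields (pos, n_seq_id, seq_id, logits) when the caller left
    // them NULL; it returns false on malformed input instead of asserting.
    if (!balloc.init(inp, vocab, p0)) {
        return -1; // let the caller surface an error instead of aborting
    }

    // the normalized batch, safe to hand to the rest of the pipeline
    const llama_batch & batch = balloc.get_batch();

    // number of tokens with logits[i] != 0, useful for sizing output buffers
    const uint32_t n_outputs = balloc.get_n_outputs();

    return batch.n_tokens > 0 ? (int32_t) n_outputs : 0;
}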