 #include "llama-impl.h"
 #include "llama-cparams.h"
 #include "llama-vocab.h"
+#include "llama-memory.h"
 
 #include <cassert>
 #include <cstring>
@@ -287,21 +288,27 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
 llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+
+    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    for (auto & cur : seq_cpl) {
+        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    }
 }
 
-bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory) {
     clear();
 
     batch = batch_inp;
 
     GGML_ASSERT(batch.n_tokens > 0);
 
-    if (!batch.pos) {
-        if (batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return false;
-        }
-    }
+    //
+    // validate input batch
+    //
 
     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
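The new init signature above replaces the explicit p0 offset with a pointer to the memory interface, so the starting position can differ per sequence. A minimal caller sketch, assuming the llama_batch_allocr, llama_vocab and llama_memory_i types from the surrounding codebase are in scope (the wrapper function itself is hypothetical, not part of this diff):

    // hypothetical wrapper, not part of this commit
    static bool decode_prepare(const llama_batch & batch,
                               const llama_vocab & vocab,
                               const llama_memory_i * memory) {
        llama_batch_allocr balloc;

        // memory may be nullptr, in which case generated positions start at 0;
        // on failure the allocator has already logged the reason
        return balloc.init(batch, vocab, memory);
    }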
@@ -323,14 +330,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         }
     }
 
-    if (!batch.pos) {
-        assert(p0 >= 0);
-        pos.resize(batch.n_tokens);
-        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = p0 + i;
-        }
-        batch.pos = pos.data();
-    }
+    //
+    // auto-generate missing fields
+    //
 
     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
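The collapsed context above defaults n_seq_id and seq_id when the caller omits them. Assuming the elided hunk body follows the usual llama.cpp convention of assigning every token to sequence 0 (an assumption, since the diff view truncates here), the defaulting amounts to this standalone sketch:

    #include <cstdint>
    #include <vector>

    using llama_seq_id = int32_t; // as in llama.h

    // assumption: with no explicit seq_id input, every token belongs to
    // the single default sequence 0
    static llama_seq_id seq_id_0 = 0;

    static void default_seq_ids(int32_t n_tokens,
                                std::vector<int32_t> & n_seq_id,
                                std::vector<llama_seq_id *> & seq_id) {
        n_seq_id.assign(n_tokens, 1);       // one sequence per token
        seq_id.assign(n_tokens, &seq_id_0); // all pointing at sequence 0
    }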
@@ -349,20 +351,69 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.seq_id = seq_id.data();
     }
 
+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (!memory) {
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
+            }
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            pos[i] = p0[seq_id];
+
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                p0[batch.seq_id[i][s]] = pos[i] + 1;
+            }
+        }
+
+        batch.pos = pos.data();
+    }
+
     if (!batch.logits) {
         // by default return the output only for the last token
         output.resize(batch.n_tokens);
         output[output.size() - 1] = true;
         batch.logits = output.data();
     }
 
+    //
+    // compute stats
+    //
+
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }
 
+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
+
+            if (s > 0) {
+                const llama_seq_id s0 = batch.seq_id[i][0];
+                const llama_seq_id s1 = batch.seq_id[i][s];
+
+                // mark that sequence s1 is coupled to s0
+                seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                // seq_cpl[s0][s1] = true;
+            }
+        }
+    }
+
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s: n_tokens  = %d\n", __func__, batch.n_tokens);
         LLAMA_LOG_DEBUG("%s: token     = %p\n", __func__, (void *) batch.token);
         LLAMA_LOG_DEBUG("%s: embd      = %p\n", __func__, (void *) batch.embd);
         LLAMA_LOG_DEBUG("%s: pos       = %p\n", __func__, (void *) batch.pos);
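A standalone worked example of the two additions above, position auto-generation and coupling detection, with the memory state stubbed by fixed values (an assumption for illustration). Each token takes the running position of its primary sequence and then advances the counter of every sequence it belongs to; a token listed under several sequences couples the secondary ones to the primary:

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_seq = 2;

        // stubbed memory state (assumption): both sequences end at position 4,
        // e.g. after a shared prompt, so generation continues from position 5
        std::vector<int> p0(n_seq, 4 + 1);

        // seq_id[i] lists the sequences of token i; seq_id[i][0] is primary
        std::vector<std::vector<int>> seq_id = { {0, 1}, {0, 1}, {0}, {1} };

        std::vector<int> pos(seq_id.size());
        std::vector<std::vector<bool>> seq_cpl(n_seq, std::vector<bool>(n_seq, false));

        for (size_t i = 0; i < seq_id.size(); ++i) {
            // position auto-generation: continue the primary sequence
            pos[i] = p0[seq_id[i][0]];
            for (int s : seq_id[i]) {
                p0[s] = pos[i] + 1; // every member sequence continues from here
            }

            // coupling detection: secondary sequences are coupled to the primary
            for (size_t s = 1; s < seq_id[i].size(); ++s) {
                seq_cpl[seq_id[i][s]][seq_id[i][0]] = true;
            }
        }

        for (size_t i = 0; i < pos.size(); ++i) {
            printf("token %zu -> pos %d\n", i, pos[i]); // 5, 6, 7, 7
        }
        printf("seq_cpl[1][0] = %d\n", (int) seq_cpl[1][0]); // 1
    }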
@@ -404,6 +455,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
                     batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
             }
             LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+
+            LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
+            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+                if (seq_pos[s0].empty()) {
+                    continue;
+                }
+
+                std::stringstream ss;
+                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                    if (seq_cpl[s0][s1]) {
+                        ss << s1 << " ";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
+                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
+    }
+
+    //
+    // consistency checks
+    //
+
+    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+            return false;
+        }
+    }
+
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but they have diverged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
         }
     }
 
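Because seq_pos[s] is an ordered set of unique positions, the continuity check added above reduces to comparing the position span with the element count: duplicates collapse in the set, so a span wider than the count implies a hole. A standalone sketch of that invariant:

    #include <cstdio>
    #include <set>

    // continuous iff the span max - min + 1 equals the number of distinct positions
    static bool is_continuous(const std::set<int> & sp) {
        return sp.empty() || *sp.rbegin() - *sp.begin() + 1 == (int) sp.size();
    }

    int main() {
        printf("%d\n", is_continuous({5, 6, 7})); // 1 -- no gaps
        printf("%d\n", is_continuous({5, 7}));    // 0 -- position 6 is missing
    }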
@@ -418,6 +521,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
+}
+
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
+}
+
 void llama_batch_allocr::clear() {
     n_outputs = 0;
 
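The -1 sentinel returned by the new accessors distinguishes sequences that are absent from the current batch. A short hypothetical usage sketch (the loop is illustrative, not part of this diff):

    // hypothetical: skip sequences with no tokens in the current batch
    for (llama_seq_id s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
        if (balloc.seq_pos_min(s) == -1) {
            continue; // sequence s is not present in this batch
        }

        // ... the batch covers positions [seq_pos_min(s), seq_pos_max(s)] for s ...
    }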
@@ -426,6 +537,14 @@ void llama_batch_allocr::clear() {
     n_seq_id.clear();
     seq_id.clear();
     output.clear();
+
+    for (auto & cur : seq_pos) {
+        cur.clear();
+    }
+
+    for (auto & cur : seq_cpl) {
+        std::fill(cur.begin(), cur.end(), false);
+    }
 }
 
 //