 #include "llama-impl.h"
 #include "llama-cparams.h"
 #include "llama-vocab.h"
+#include "llama-memory.h"

 #include <cassert>
 #include <cstring>
@@ -625,21 +626,27 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
 llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
+
+    seq_pos.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    seq_cpl.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    for (auto & cur : seq_cpl) {
+        cur.resize(LLAMA_MAX_PARALLEL_SEQUENCES);
+    }
 }

-bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
+bool llama_batch_allocr::init(
+        const llama_batch & batch_inp,
+        const llama_vocab & vocab,
+        const llama_memory_i * memory) {
     clear();

     batch = batch_inp;

     GGML_ASSERT(batch.n_tokens > 0);

-    if (!batch.pos) {
-        if (batch.seq_id) {
-            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
-            return false;
-        }
-    }
+    //
+    // validate input batch
+    //

     if (batch.token) {
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -661,14 +668,9 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         }
     }

-    if (!batch.pos) {
-        assert(p0 >= 0);
-        pos.resize(batch.n_tokens);
-        for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = p0 + i;
-        }
-        batch.pos = pos.data();
-    }
+    //
+    // auto-generate missing fields
+    //

     if (!batch.n_seq_id) {
         n_seq_id.resize(batch.n_tokens);
@@ -687,20 +689,69 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
         batch.seq_id = seq_id.data();
     }

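+    // when the positions are not provided, derive them from the memory state below
+    // (without a memory, each sequence simply starts at position 0)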
+    if (!batch.pos) {
+        pos.resize(batch.n_tokens);
+
+        // initialize the starting position for each sequence based on the positions in the memory
+        llama_pos p0[LLAMA_MAX_PARALLEL_SEQUENCES];
+        for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+            if (!memory) {
+                p0[s] = 0;
+            } else {
+                p0[s] = memory->seq_pos_max(s) + 1;
+            }
+        }
+
+        for (int32_t i = 0; i < batch.n_tokens; i++) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+
+            pos[i] = p0[seq_id];
+
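+            // every sequence that this token is assigned to continues from the position right after it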
+            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+                p0[batch.seq_id[i][s]] = pos[i] + 1;
+            }
+        }
+
+        batch.pos = pos.data();
+    }
+
     if (!batch.logits) {
         // by default return the output only for the last token
         output.resize(batch.n_tokens);
         output[output.size() - 1] = true;
         batch.logits = output.data();
     }

+    //
+    // compute stats
+    //
+
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         n_outputs += batch.logits[i] != 0;
     }

+    // determine coupled sequences
+    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
+    for (int32_t i = 0; i < batch.n_tokens; ++i) {
+        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
+            seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
+
+            if (s > 0) {
+                const llama_seq_id s0 = batch.seq_id[i][0];
+                const llama_seq_id s1 = batch.seq_id[i][s];
+
+                // mark that sequence s1 is coupled to s0
+                seq_cpl[s1][s0] = true;
+
+                // note: the other way around is not necessary for now
+                // seq_cpl[s0][s1] = true;
+            }
+        }
+    }
+
     if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: input batch info (p0 = %d):\n", __func__, p0);
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens);
         LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
         LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
         LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
@@ -742,6 +793,58 @@ bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab &
                         batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
             }
             LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+
+            LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
+            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+                if (seq_pos[s0].empty()) {
+                    continue;
+                }
+
+                std::stringstream ss;
+                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                    if (seq_cpl[s0][s1]) {
+                        ss << s1 << " ";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
+                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
+    }
+
+    //
+    // consistency checks
+    //
+
+    for (int32_t s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos[s].empty()) {
+            continue;
+        }
+
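+        // the batch has to continue exactly from the last position stored in the memory for this sequence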
+        if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
+            LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
+            return false;
+        }
+
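+        // seq_pos[s] holds the unique positions seen for this sequence, so a span wider than its size implies a gap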
+        if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
+            LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
+            return false;
+        }
+    }
+
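+    // sequences that are coupled in this batch must already have identical position ranges in the memory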
+    if (memory) {
+        for (int32_t s0 = 0; s0 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s0) {
+            for (int32_t s1 = 0; s1 < LLAMA_MAX_PARALLEL_SEQUENCES; ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
+                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
+                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but they have diverged\n", __func__, s0, s1);
+                        return false;
+                    }
+                }
+            }
         }
     }

@@ -756,6 +859,14 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }

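+// smallest/largest position seen for this sequence in the current batch, or -1 if the sequence is not present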
+llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
+}
+
+llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
+    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
+}
+
 void llama_batch_allocr::clear() {
     n_outputs = 0;

@@ -764,6 +875,14 @@ void llama_batch_allocr::clear() {
     n_seq_id.clear();
     seq_id.clear();
     output.clear();
+
+    for (auto & cur : seq_pos) {
+        cur.clear();
+    }
+
+    for (auto & cur : seq_cpl) {
+        std::fill(cur.begin(), cur.end(), false);
+    }
 }

 //