@@ -301,11 +301,10 @@ bool llama_batch_allocr::init(
         const llama_batch & batch_inp,
         const llama_vocab & vocab,
         const llama_memory_i * memory,
-        bool embd_all) {
+        uint32_t n_embd,
+        bool output_all) {
     clear();

-    split_reset();
-
     batch = batch_inp;

     GGML_ASSERT(batch.n_tokens > 0);
@@ -382,7 +381,7 @@ bool llama_batch_allocr::init(
     }

     if (!batch.logits) {
-        if (embd_all) {
+        if (output_all) {
             // return the output for all tokens
             output.resize(batch.n_tokens, true);
         } else {
@@ -392,7 +391,7 @@ bool llama_batch_allocr::init(
         }

         batch.logits = output.data();
-    } else if (embd_all) {
+    } else if (output_all) {
         bool warn = false;

         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -417,6 +416,8 @@ bool llama_batch_allocr::init(
         n_outputs += batch.logits[i] != 0;
     }

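+    // keep a copy of the embedding width - add_ubatch() uses it to size ubatch.embd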
+    this->n_embd = n_embd;
+
     // determine coupled sequences
     // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -572,6 +573,8 @@ bool llama_batch_allocr::init(

     // TODO: check that positions are increasing

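+    // the batch is now fully validated - prepare a fresh split state for it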
+    split_reset();
+
     return true;
 }

@@ -580,7 +583,7 @@ const llama_batch & llama_batch_allocr::get_batch() const {
 }

 uint32_t llama_batch_allocr::get_n_tokens() const {
-    return pos.size();
+    return batch.n_tokens;
 }

 uint32_t llama_batch_allocr::get_n_outputs() const {
@@ -609,41 +612,20 @@ void llama_batch_allocr::split_reset() {
 }

 llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
-    llama_ubatch res {
-        /*.equal_seqs   =*/ false,
-        /*.n_tokens     =*/ 0,
-        /*.n_seq_tokens =*/ 1,
-        /*.n_seqs       =*/ 0,
-
-        /*.token        =*/ nullptr,
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ nullptr,
-        /*.n_seq_id     =*/ nullptr,
-        /*.seq_id       =*/ nullptr,
-        /*.output       =*/ nullptr
-    };
-
     uint32_t cur_idx = 0;
     while (cur_idx < used.size() && used[cur_idx]) {
         ++cur_idx;
     }

     if (cur_idx >= used.size()) {
-        return res;
+        return {};
     }

     std::vector<int32_t> idxs;

     while (true) {
-        res.n_tokens++;
-        res.n_seqs++;
-
         idxs.push_back(cur_idx);

-        if (output[cur_idx] != 0) {
-            out_ids.push_back(cur_idx);
-        }
-
         used[cur_idx] = true;

         ++cur_idx;
@@ -652,31 +634,15 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
             break;
         }

-        if (res.n_tokens >= n_ubatch) {
+        if (idxs.size() >= n_ubatch) {
             break;
         }
     }

-    add_ubatch(res, idxs);
-
-    return res;
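+    // a contiguous run of unused tokens; each token counts as its own stream
+    // (n_seqs == n_tokens, n_seq_tokens == 1), hence equal_seqs == false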
+    return add_ubatch(idxs, idxs.size(), false);
 }

 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
-    llama_ubatch res {
-        /*.equal_seqs   =*/ true,
-        /*.n_tokens     =*/ 0,
-        /*.n_seq_tokens =*/ 0,
-        /*.n_seqs       =*/ 0,
-
-        /*.token        =*/ nullptr,
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ nullptr,
-        /*.n_seq_id     =*/ nullptr,
-        /*.seq_id       =*/ nullptr,
-        /*.output       =*/ nullptr
-    };
-
     std::vector<seq_set_t> cur_seq_set;

     // determine the sequence sets participating in this ubatch
@@ -685,35 +651,45 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             continue;
         }

-        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        bool add = true;
+
+        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
             // no overlap with existing sequence sets:
-            if ((cur_seq_set[s] & seq_set[i]).none()) {
-                cur_seq_set.push_back(seq_set[i]);
+            if (!(cur_seq_set[s] & seq_set[i]).none()) {
+                add = false;
+                break;
+            }
+        }

-                if (cur_seq_set.size() > (size_t) n_ubatch) {
-                    break;
-                }
+        if (add) {
+            cur_seq_set.push_back(seq_set[i]);
+
+            if (cur_seq_set.size() > n_ubatch) {
+                break;
             }
         }
     }

-    res.n_seqs = cur_seq_set.size();
+    const uint32_t n_seqs = cur_seq_set.size();
+
+    if (n_seqs == 0) {
+        return {};
+    }

-    std::vector<int32_t> cur_idx(cur_seq_set.size(), 0);
+    std::vector<int32_t> cur_idx(n_seqs, 0);

-    for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+    for (uint32_t s = 0; s < n_seqs; ++s) {
         while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
             ++cur_idx[s];
         }
     }

-    std::vector<int32_t> idxs;
+    std::vector<idx_vec_t> idxs_per_seq(n_seqs);

-    // TODO: reorder from 012301230123..., to 000...111...222...333...
     while (true) {
         bool can_expand = true;

-        for (size_t s = 0; s < cur_seq_set.size(); ++s) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
                 can_expand = false;
                 break;
@@ -724,71 +700,49 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
724700 break ;
725701 }
726702
727- res.n_tokens += res.n_seqs ;
728-
729- for (size_t s = 0 ; s < cur_seq_set.size (); ++s) {
703+ for (uint32_t s = 0 ; s < n_seqs; ++s) {
730704 const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
731- idxs.push_back (idx);
732-
733- if (output[idx] != 0 ) {
734- out_ids.push_back (idx);
735- }
705+ idxs_per_seq[s].push_back (idx);
736706
737707 used[idx] = true ;
738708
739709 ++cur_idx[s];
740710 }
741711
742- if (res. n_tokens + res. n_seqs > n_ubatch) {
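+        // every pass adds one token to each of the n_seqs streams; stop before
+        // the next pass would exceed n_ubatch (e.g. n_seqs = 4, n_ubatch = 10:
+        // stop at 2 tokens per stream, 8 tokens total)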
+        if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
             break;
         }
     }

-    add_ubatch(res, idxs);
+    std::vector<int32_t> idxs;

-    return res;
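+    // flatten the per-stream indices so the tokens are laid out as
+    // 000...111...222... (this realizes the old reorder TODO)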
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
+    }
+
+    return add_ubatch(idxs, n_seqs, true);
 }

 llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
-    llama_ubatch res {
-        /*.equal_seqs   =*/ true,
-        /*.n_tokens     =*/ 0,
-        /*.n_seq_tokens =*/ 0,
-        /*.n_seqs       =*/ 1,
-
-        /*.token        =*/ nullptr,
-        /*.embd         =*/ nullptr,
-        /*.pos          =*/ nullptr,
-        /*.n_seq_id     =*/ nullptr,
-        /*.seq_id       =*/ nullptr,
-        /*.output       =*/ nullptr,
-    };
-
     uint32_t cur_idx = 0;
     while (cur_idx < used.size() && used[cur_idx]) {
         ++cur_idx;
     }

     if (cur_idx >= used.size()) {
-        return res;
+        return {};
     }

     auto cur_seq_set = seq_set[cur_idx];

     std::vector<int32_t> idxs;

     while (true) {
-        res.n_tokens++;
-
         idxs.push_back(cur_idx);

-        if (output[cur_idx] != 0) {
-            out_ids.push_back(cur_idx);
-        }
-
         used[cur_idx] = true;

-        if (res.n_tokens >= n_ubatch) {
+        if (idxs.size() >= n_ubatch) {
             break;
         }

@@ -803,9 +757,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         cur_seq_set = seq_set[cur_idx];
     }

-    add_ubatch(res, idxs);
-
-    return res;
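+    // the collected tokens form a single stream in batch order, so the ubatch
+    // is built with n_seqs == 1 and equal_seqs == true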
+    return add_ubatch(idxs, 1, true);
 }

 void llama_batch_allocr::clear() {
@@ -834,37 +786,60 @@ void llama_batch_allocr::clear() {
     seq_set_map.clear();
 }

-void llama_batch_allocr::add_ubatch(llama_ubatch & res, const std::vector<int32_t> & idxs) {
-    ubatches.emplace_back();
+llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
+    const uint32_t n_tokens = idxs.size();

-    auto & ubatch = ubatches.back();
+    LLAMA_LOG_DEBUG("add_ubatch: n_tokens = %d, n_seqs = %d, equal_seqs = %d\n", n_tokens, n_seqs, equal_seqs);

-    assert(res.n_tokens == idxs.size());
+    assert(n_tokens%n_seqs == 0);

-    const auto n_tokens = res.n_tokens;
+    ubatches.emplace_back();
+
+    auto & ubatch = ubatches.back();

     ubatch.token.resize(n_tokens);
-    //ubatch.embd.resize(0); // TODO
+    ubatch.embd.resize((int64_t) n_tokens*n_embd);
848802 ubatch.pos .resize (n_tokens);
849803 ubatch.n_seq_id .resize (n_tokens);
850804 ubatch.seq_id .resize (n_tokens);
851805 ubatch.output .resize (n_tokens);
852806
853807 for (size_t i = 0 ; i < idxs.size (); ++i) {
854- ubatch.token [i] = batch.token [idxs[i]];
855- // ubatch.embd[i] = batch.embd[idxs[i]]; // TODO
808+ if (batch.token ) {
809+ ubatch.token [i] = batch.token [idxs[i]];
810+ }
811+
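+        // embeddings are row-major, n_embd floats per token - gather the
+        // selected rows into the ubatch's own buffer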
+        if (batch.embd) {
+            memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
+        }
+
         ubatch.pos[i]      = batch.pos[idxs[i]];
         ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
         ubatch.seq_id[i]   = batch.seq_id[idxs[i]];
         ubatch.output[i]   = batch.logits[idxs[i]];
+
+        if (ubatch.output[i]) {
+            out_ids.push_back(idxs[i]);
+        }
     }

-    res.token    = ubatch.token.data();
-    //res.embd     = ubatch.embd.data(); // TODO
-    res.pos      = ubatch.pos.data();
-    res.n_seq_id = ubatch.n_seq_id.data();
-    res.seq_id   = ubatch.seq_id.data();
-    res.output   = ubatch.output.data();
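+    // the returned llama_ubatch is a non-owning view into storage that stays
+    // alive in the ubatches member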
+    llama_ubatch res {
+        /*.equal_seqs   =*/ equal_seqs,
+        /*.n_tokens     =*/ n_tokens,
+        /*.n_seq_tokens =*/ n_tokens/n_seqs,
+        /*.n_seqs       =*/ n_seqs,
+
+        /*.token        =*/ batch.token ? ubatch.token.data() : nullptr,
+        /*.embd         =*/ batch.embd ? ubatch.embd.data() : nullptr,
+        /*.pos          =*/ ubatch.pos.data(),
+        /*.n_seq_id     =*/ ubatch.n_seq_id.data(),
+        /*.seq_id       =*/ ubatch.seq_id.data(),
+        /*.output       =*/ ubatch.output.data(),
+    };
+
+    LLAMA_LOG_DEBUG("%s: added ubatch of size %d\n", __func__, res.n_tokens);
+
+    return res;
 }

 //
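For context, a minimal driver sketch of how the reworked allocator is meant to be used. This is illustrative only: the balloc and n_ubatch names, the default-constructed allocator, and the processing placeholder are assumptions, not part of this diff. What the diff does establish is the new init() signature, that init() now ends by calling split_reset(), and that each split_*() method returns an empty ubatch once all tokens are consumed.

    llama_batch_allocr balloc;  // hypothetical instance name

    // init() validates the batch and leaves the split state ready
    if (balloc.init(batch, vocab, memory, n_embd, output_all)) {
        while (true) {
            // split_equal() or split_seq() could be used here instead
            llama_ubatch ubatch = balloc.split_simple(n_ubatch);
            if (ubatch.n_tokens == 0) {
                break;  // all tokens consumed
            }
            // ... process the ubatch ...
        }
    }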