File tree Expand file tree Collapse file tree 5 files changed +26
-5
lines changed Expand file tree Collapse file tree 5 files changed +26
-5
lines changed Original file line number Diff line number Diff line change @@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
166166
167167 // note: tracking the other way around is not necessary for now
168168 // seq_cpl[s0][s1] = true;
169+
170+ has_cpl = true ;
169171 }
170172 }
171173 }
@@ -472,9 +474,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
472474 return ubatch_add (idxs, idxs.size (), false );
473475}
474476
475- llama_ubatch llama_batch_allocr::split_equal (uint32_t n_ubatch) {
477+ llama_ubatch llama_batch_allocr::split_equal (uint32_t n_ubatch, bool sequential) {
478+ if (sequential && has_cpl) {
479+ LLAMA_LOG_ERROR (" %s: sequential split is not supported when there are coupled sequences in the input batch\n " , __func__);
480+
481+ return {};
482+ }
483+
476484 std::vector<seq_set_t > cur_seq_set;
477485
486+ llama_seq_id last_seq_id = -1 ;
487+
478488 // determine the non-overlapping sequence sets participating in this ubatch
479489 for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
480490 if (used[i]) {
@@ -491,9 +501,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
491501 }
492502 }
493503
504+ // accept only increasing sequence ids
505+ if (sequential) {
506+ add = add && (cur_seq_set.empty () || batch.seq_id [i][0 ] == last_seq_id + 1 );
507+ }
508+
494509 if (add) {
495510 cur_seq_set.push_back (seq_set[i]);
496511
512+ last_seq_id = batch.seq_id [i][0 ];
513+
497514 if (cur_seq_set.size () > n_ubatch) {
498515 break ;
499516 }
Original file line number Diff line number Diff line change @@ -72,7 +72,8 @@ class llama_batch_allocr {
7272 llama_ubatch split_simple (uint32_t n_ubatch);
7373
7474 // make ubatches of equal-length sequences sets
75- llama_ubatch split_equal (uint32_t n_ubatch);
75+ // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
76+ llama_ubatch split_equal (uint32_t n_ubatch, bool sequential);
7677
7778 // sequence-set-wise split - each ubatch contains a single sequence-set
7879 llama_ubatch split_seq (uint32_t n_ubatch);
@@ -115,6 +116,9 @@ class llama_batch_allocr {
115116 using pos_set_t = std::set<llama_pos>;
116117 using seq_cpl_t = std::vector<bool >;
117118
119+ // helper flag to quickly determine if there are any coupled sequences in the batch
120+ bool has_cpl;
121+
118122 std::vector<pos_set_t > seq_pos; // seq_pos[s]: the set of positions in sequence s
119123 std::vector<seq_cpl_t > seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
120124
Original file line number Diff line number Diff line change @@ -140,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
140140
141141 std::vector<llama_ubatch> ubatches;
142142 while (true ) {
143- auto ubatch = balloc.split_equal (n_ubatch);
143+ auto ubatch = balloc.split_equal (n_ubatch, false );
144144
145145 if (ubatch.n_tokens == 0 ) {
146146 break ;
Original file line number Diff line number Diff line change @@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
7070 // if all tokens are output, split by sequence
7171 ubatch = balloc.split_seq (n_ubatch);
7272 } else {
73- ubatch = balloc.split_equal (n_ubatch);
73+ ubatch = balloc.split_equal (n_ubatch, false );
7474 }
7575
7676 if (ubatch.n_tokens == 0 ) {
Original file line number Diff line number Diff line change @@ -374,7 +374,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
374374 // if all tokens are output, split by sequence
375375 ubatch = balloc.split_seq (n_ubatch);
376376 } else {
377- ubatch = balloc.split_equal (n_ubatch);
377+ ubatch = balloc.split_equal (n_ubatch, false );
378378 }
379379
380380 if (balloc.get_n_used () < balloc.get_n_tokens ()) {
You can’t perform that action at this time.
0 commit comments