6 files changed: +34 -1
lines changed Original file line number Diff line number Diff line change @@ -405,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
405405 return n_outputs;
406406}
407407
// Returns the running count of batch tokens that the split_* functions have
// marked as used. The counter is zeroed in split_reset() and incremented each
// time a token's used[] flag is set, so callers can compare it against
// get_n_tokens() to detect a split that did not consume the whole batch.
408+ uint32_t llama_batch_allocr::get_n_used () const {
409+ return n_used;
410+ }
411+
// Returns a mutable reference to the array of output indices, in the order they
// were encountered during the ubatch splitting (cleared by split_reset()).
408412std::vector<int32_t > & llama_batch_allocr::get_out_ids () {
409413 return out_ids;
410414}
@@ -420,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
420424void llama_batch_allocr::split_reset () {
421425 out_ids.clear ();
422426
427+ n_used = 0 ;
428+
423429 used.clear ();
424430 used.resize (get_n_tokens (), false );
425431
@@ -444,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
444450 idxs.push_back (cur_idx);
445451
446452 used[cur_idx] = true ;
453+ ++n_used;
447454
448455 ++cur_idx;
449456
@@ -529,6 +536,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
529536 idxs_per_seq[s].push_back (idx);
530537
531538 used[idx] = true ;
539+ ++n_used;
532540
533541 ++cur_idx[s];
534542 }
@@ -570,6 +578,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
570578 idxs.push_back (cur_idx);
571579
572580 used[cur_idx] = true ;
581+ ++n_used;
573582
574583 if (idxs.size () >= n_ubatch) {
575584 break ;
Original file line number Diff line number Diff line change @@ -54,6 +54,7 @@ class llama_batch_allocr {
5454
5555 uint32_t get_n_tokens () const ;
5656 uint32_t get_n_outputs () const ;
57+ uint32_t get_n_used () const ;
5758
5859 // the array of output indices in the order they were encountered during the ubatch splitting
5960 std::vector<int32_t > & get_out_ids ();
@@ -125,6 +126,8 @@ class llama_batch_allocr {
125126 // batch indices of the output
126127 std::vector<int32_t > out_ids;
127128
129+ uint32_t n_used;
130+
128131 // used[i] indicates if token i has already been used in a previous ubatch
129132 std::vector<bool > used;
130133
Original file line number Diff line number Diff line change @@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
113113 ubatches.push_back (std::move (ubatch)); // NOLINT
114114 }
115115
116+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
117+ // failed to find a suitable split
118+ break ;
119+ }
120+
116121 auto sinfos_base = kv_base->prepare (ubatches);
117122 if (sinfos_base.empty ()) {
118123 break ;
@@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
144149 ubatches.push_back (std::move (ubatch)); // NOLINT
145150 }
146151
152+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
153+ // failed to find a suitable split
154+ break ;
155+ }
156+
147157 auto sinfos_base = kv_base->prepare (ubatches);
148158 if (sinfos_base.empty ()) {
149159 break ;
Original file line number Diff line number Diff line change @@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
360360 ubatches.push_back (std::move (ubatch)); // NOLINT
361361 }
362362
363+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
364+ // failed to find a suitable split
365+ break ;
366+ }
367+
363368 auto sinfos = prepare (ubatches);
364369 if (sinfos.empty ()) {
365370 break ;
Original file line number Diff line number Diff line change @@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
8080 ubatches.push_back (std::move (ubatch)); // NOLINT
8181 }
8282
83+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
84+ // failed to find a suitable split
85+ break ;
86+ }
87+
8388 // prepare the recurrent batches first
8489 if (!mem_recr->prepare (ubatches)) {
8590 // TODO: will the recurrent cache be in an undefined context at this point?
Original file line number Diff line number Diff line change @@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
377377 ubatch = balloc.split_equal (n_ubatch);
378378 }
379379
380- if (ubatch.n_tokens == 0 ) {
380+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
381+ // failed to find a suitable split
381382 break ;
382383 }
383384
You can’t perform that action at this time.
0 commit comments