File tree: 6 files changed, +34 -1 lines changed
lines changed Original file line number Diff line number Diff line change @@ -411,6 +411,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
411411 return n_outputs;
412412}
413413
414+ uint32_t llama_batch_allocr::get_n_used () const {
415+ return n_used;
416+ }
417+
414418std::vector<int32_t > & llama_batch_allocr::get_out_ids () {
415419 return out_ids;
416420}
@@ -426,6 +430,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
426430void llama_batch_allocr::split_reset () {
427431 out_ids.clear ();
428432
433+ n_used = 0 ;
434+
429435 used.clear ();
430436 used.resize (get_n_tokens (), false );
431437
@@ -450,6 +456,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
450456 idxs.push_back (cur_idx);
451457
452458 used[cur_idx] = true ;
459+ ++n_used;
453460
454461 ++cur_idx;
455462
@@ -535,6 +542,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
535542 idxs_per_seq[s].push_back (idx);
536543
537544 used[idx] = true ;
545+ ++n_used;
538546
539547 ++cur_idx[s];
540548 }
@@ -576,6 +584,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
576584 idxs.push_back (cur_idx);
577585
578586 used[cur_idx] = true ;
587+ ++n_used;
579588
580589 if (idxs.size () >= n_ubatch) {
581590 break ;
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ class llama_batch_allocr {
5656
5757 uint32_t get_n_tokens () const ;
5858 uint32_t get_n_outputs () const ;
59+ uint32_t get_n_used () const ;
5960
6061 // the array of output indices in the order they were encountered during the ubatch splitting
6162 std::vector<int32_t > & get_out_ids ();
@@ -127,6 +128,8 @@ class llama_batch_allocr {
127128 // batch indices of the output
128129 std::vector<int32_t > out_ids;
129130
131+ uint32_t n_used;
132+
130133 // used[i] indicates if token i has already been used in a previous ubatch
131134 std::vector<bool > used;
132135
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
113113 ubatches.push_back (std::move (ubatch)); // NOLINT
114114 }
115115
116+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
117+ // failed to find a suitable split
118+ break ;
119+ }
120+
116121 auto sinfos_base = kv_base->prepare (ubatches);
117122 if (sinfos_base.empty ()) {
118123 break ;
@@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
144149 ubatches.push_back (std::move (ubatch)); // NOLINT
145150 }
146151
152+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
153+ // failed to find a suitable split
154+ break ;
155+ }
156+
147157 auto sinfos_base = kv_base->prepare (ubatches);
148158 if (sinfos_base.empty ()) {
149159 break ;
Original file line number | Diff line number | Diff line change
@@ -353,6 +353,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
353353 ubatches.push_back (std::move (ubatch)); // NOLINT
354354 }
355355
356+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
357+ // failed to find a suitable split
358+ break ;
359+ }
360+
356361 auto sinfos = prepare (ubatches);
357362 if (sinfos.empty ()) {
358363 break ;
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
8080 ubatches.push_back (std::move (ubatch)); // NOLINT
8181 }
8282
83+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
84+ // failed to find a suitable split
85+ break ;
86+ }
87+
8388 // prepare the recurrent batches first
8489 if (!mem_recr->prepare (ubatches)) {
8590 // TODO: will the recurrent cache be in an undefined context at this point?
Original file line number | Diff line number | Diff line change
@@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
377377 ubatch = balloc.split_equal (n_ubatch);
378378 }
379379
380- if (ubatch.n_tokens == 0 ) {
380+ if (balloc.get_n_used () < balloc.get_n_tokens ()) {
381+ // failed to find a suitable split
381382 break ;
382383 }
383384
You can’t perform that action at this time.
0 commit comments