@@ -363,30 +363,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
363363}
364364
// NOTE(review): this span is a rendered diff of llama_memory_recurrent::init_batch, not
// compilable C++. Each line begins with line numbers fused into the text: a `-` after the
// number marks a line of the old version, a `+` marks a line of the new version, and lines
// carrying two fused numbers (e.g. 365365) are unchanged context. (Inferred from the hunk
// header and the numbering pattern - confirm against the upstream commit.)
//
// New version (`+` and context lines), read straight through:
//   1. balloc.split_reset() is called before any splitting - presumably it rewinds the
//      allocator's split state; its definition is not visible here, TODO confirm.
//   2. ubatches are collected by calling balloc.split_seq(n_ubatch) when embd_all is true,
//      otherwise balloc.split_equal(n_ubatch), until a ubatch with n_tokens == 0 comes back.
//   3. If prepare(ubatches) returns false, `break` leaves the do { } while (false) block and
//      control reaches the single LLAMA_MEMORY_STATUS_FAILED_PREPARE return at the end.
//   4. Otherwise a llama_memory_recurrent_context built from `this` and the moved ubatches
//      is returned from inside the block.
//
// Old version (`-` lines): the same splitting loop with no split_reset() call visible, and
// the FAILED_PREPARE context returned directly inside the `if (!prepare(ubatches))` branch.
365365llama_memory_context_ptr llama_memory_recurrent::init_batch (llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
366- std::vector<llama_ubatch> ubatches;
// new: the body is wrapped in a block that runs exactly once (see `while (false)` below),
// so `break` acts as a jump to the shared failure return after the block
366+ do {
367+ balloc.split_reset ();
367368
368- while (true ) {
369- llama_ubatch ubatch;
// new: ubatches is now declared inside the do-block, so it is scoped to this single attempt
369+ std::vector<llama_ubatch> ubatches;
370+ while (true ) {
371+ llama_ubatch ubatch;
370372
371- if (embd_all) {
372- // if all tokens are output, split by sequence
373- ubatch = balloc.split_seq (n_ubatch);
374- } else {
375- ubatch = balloc.split_equal (n_ubatch);
373+ if (embd_all) {
374+ // if all tokens are output, split by sequence
375+ ubatch = balloc.split_seq (n_ubatch);
376+ } else {
377+ ubatch = balloc.split_equal (n_ubatch);
378+ }
379+
// new: an empty ubatch (n_tokens == 0) ends the inner splitting loop
380+ if (ubatch.n_tokens == 0 ) {
381+ break ;
382+ }
383+
384+ ubatches.push_back (std::move (ubatch)); // NOLINT
376385 }
377386
378- if (ubatch. n_tokens == 0 ) {
// new: the context lines `break ;` and `}` below are shared by both versions; in the new
// version this `break` exits the outer do-block when prepare() fails
387+ if (! prepare (ubatches) ) {
379388 break ;
380389 }
381390
382- ubatches.push_back (std::move (ubatch)); // NOLINT
383- }
384-
385- if (!prepare (ubatches)) {
386- return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
387- }
// new: success path returns from inside the do-block
391+ return std::make_unique<llama_memory_recurrent_context>(this , std::move (ubatches));
392+ } while (false );
388393
389- return std::make_unique<llama_memory_recurrent_context>(this , std::move (ubatches) );
// new: reached only via the `break` above, i.e. when prepare() returned false
394+ return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE );
390395}
391396
392397llama_memory_context_ptr llama_memory_recurrent::init_full () {
0 commit comments