@@ -460,9 +460,9 @@ struct llama_batch_manager_i {
460460
461461 virtual bool is_done () const = 0;
462462 virtual llama_ubatch next () = 0;
463- virtual bool prepare () = 0;
463+ virtual bool prepare (const llama_ubatch & ubatch ) = 0;
464464 virtual void restore () = 0;
465- virtual void update () = 0;
465+ virtual void update (const llama_ubatch & ubatch ) = 0;
466466 virtual void finalize () = 0;
467467
468468 // TODO: might be temporary
@@ -532,7 +532,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
532532 }
533533
534534 virtual llama_ubatch next () override {
535- ubatch = llama_ubatch ();
535+ llama_ubatch ubatch = llama_ubatch ();
536536
537537 const auto & cparams = lctx.cparams ;
538538 const auto & kv_self = lctx.kv_self ;
@@ -557,7 +557,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
557557 return ubatch;
558558 }
559559
560- virtual bool prepare () override {
560+ virtual bool prepare (const llama_ubatch & ubatch ) override {
561561 const auto & cparams = lctx.cparams ;
562562 const auto & hparams = lctx.model .hparams ;
563563 const auto & batch = lctx.sbatch .batch ;
@@ -644,7 +644,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
644644 kv_slot_restorer.restore (lctx.kv_self );
645645 }
646646
647- virtual void update () override {
647+ virtual void update (const llama_ubatch & ubatch ) override {
648648 auto & kv_self = lctx.kv_self ;
649649
650650 // update the kv ring buffer
@@ -682,8 +682,6 @@ struct llama_batch_manager : public llama_batch_manager_i {
682682
683683 const llama_batch & batch;
684684
685- llama_ubatch ubatch;
686-
687685 llama_kv_slot_restorer kv_slot_restorer;
688686};
689687
@@ -728,7 +726,7 @@ int llama_context::decode(llama_batch & inp_batch) {
728726 while (!bman->is_done ()) {
729727 llama_ubatch ubatch = bman->next ();
730728
731- if (!bman->prepare ()) {
729+ if (!bman->prepare (ubatch )) {
732730 LLAMA_LOG_ERROR (" %s: failed to prepare ubatch\n " , __func__);
733731 bman->restore ();
734732 return -3 ;
@@ -782,7 +780,7 @@ int llama_context::decode(llama_batch & inp_batch) {
782780 }
783781 }
784782
785- bman->update ();
783+ bman->update (ubatch );
786784
787785 // plot the computation graph in dot format (for debugging purposes)
788786 // if (n_past%100 == 0) {
0 commit comments