@@ -460,9 +460,9 @@ struct llama_batch_manager_i {
460
460
461
461
virtual bool is_done () const = 0;
462
462
virtual llama_ubatch next () = 0;
463
- virtual bool prepare () = 0;
463
+ virtual bool prepare (const llama_ubatch & ubatch ) = 0;
464
464
virtual void restore () = 0;
465
- virtual void update () = 0;
465
+ virtual void update (const llama_ubatch & ubatch ) = 0;
466
466
virtual void finalize () = 0;
467
467
468
468
// TODO: might be temporary
@@ -532,7 +532,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
532
532
}
533
533
534
534
virtual llama_ubatch next () override {
535
- ubatch = llama_ubatch ();
535
+ llama_ubatch ubatch = llama_ubatch ();
536
536
537
537
const auto & cparams = lctx.cparams ;
538
538
const auto & kv_self = lctx.kv_self ;
@@ -557,7 +557,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
557
557
return ubatch;
558
558
}
559
559
560
- virtual bool prepare () override {
560
+ virtual bool prepare (const llama_ubatch & ubatch ) override {
561
561
const auto & cparams = lctx.cparams ;
562
562
const auto & hparams = lctx.model .hparams ;
563
563
const auto & batch = lctx.sbatch .batch ;
@@ -644,7 +644,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
644
644
kv_slot_restorer.restore (lctx.kv_self );
645
645
}
646
646
647
- virtual void update () override {
647
+ virtual void update (const llama_ubatch & ubatch ) override {
648
648
auto & kv_self = lctx.kv_self ;
649
649
650
650
// update the kv ring buffer
@@ -682,8 +682,6 @@ struct llama_batch_manager : public llama_batch_manager_i {
682
682
683
683
const llama_batch & batch;
684
684
685
- llama_ubatch ubatch;
686
-
687
685
llama_kv_slot_restorer kv_slot_restorer;
688
686
};
689
687
@@ -728,7 +726,7 @@ int llama_context::decode(llama_batch & inp_batch) {
728
726
while (!bman->is_done ()) {
729
727
llama_ubatch ubatch = bman->next ();
730
728
731
- if (!bman->prepare ()) {
729
+ if (!bman->prepare (ubatch )) {
732
730
LLAMA_LOG_ERROR (" %s: failed to prepare ubatch\n " , __func__);
733
731
bman->restore ();
734
732
return -3 ;
@@ -782,7 +780,7 @@ int llama_context::decode(llama_batch & inp_batch) {
782
780
}
783
781
}
784
782
785
- bman->update ();
783
+ bman->update (ubatch );
786
784
787
785
// plot the computation graph in dot format (for debugging purposes)
788
786
// if (n_past%100 == 0) {
0 commit comments