Skip to content

Commit d1d8d53

Browse files
committed
bman : remove ubatch member
ggml-ci
1 parent: ef358ee · commit: d1d8d53

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

src/llama-context.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,9 @@ struct llama_batch_manager_i {
460460

461461
virtual bool is_done() const = 0;
462462
virtual llama_ubatch next() = 0;
463-
virtual bool prepare() = 0;
463+
virtual bool prepare(const llama_ubatch & ubatch) = 0;
464464
virtual void restore() = 0;
465-
virtual void update() = 0;
465+
virtual void update(const llama_ubatch & ubatch) = 0;
466466
virtual void finalize() = 0;
467467

468468
// TODO: might be temporary
@@ -532,7 +532,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
532532
}
533533

534534
virtual llama_ubatch next() override {
535-
ubatch = llama_ubatch();
535+
llama_ubatch ubatch = llama_ubatch();
536536

537537
const auto & cparams = lctx.cparams;
538538
const auto & kv_self = lctx.kv_self;
@@ -557,7 +557,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
557557
return ubatch;
558558
}
559559

560-
virtual bool prepare() override {
560+
virtual bool prepare(const llama_ubatch & ubatch) override {
561561
const auto & cparams = lctx.cparams;
562562
const auto & hparams = lctx.model.hparams;
563563
const auto & batch = lctx.sbatch.batch;
@@ -644,7 +644,7 @@ struct llama_batch_manager : public llama_batch_manager_i {
644644
kv_slot_restorer.restore(lctx.kv_self);
645645
}
646646

647-
virtual void update() override {
647+
virtual void update(const llama_ubatch & ubatch) override {
648648
auto & kv_self = lctx.kv_self;
649649

650650
// update the kv ring buffer
@@ -682,8 +682,6 @@ struct llama_batch_manager : public llama_batch_manager_i {
682682

683683
const llama_batch & batch;
684684

685-
llama_ubatch ubatch;
686-
687685
llama_kv_slot_restorer kv_slot_restorer;
688686
};
689687

@@ -728,7 +726,7 @@ int llama_context::decode(llama_batch & inp_batch) {
728726
while (!bman->is_done()) {
729727
llama_ubatch ubatch = bman->next();
730728

731-
if (!bman->prepare()) {
729+
if (!bman->prepare(ubatch)) {
732730
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
733731
bman->restore();
734732
return -3;
@@ -782,7 +780,7 @@ int llama_context::decode(llama_batch & inp_batch) {
782780
}
783781
}
784782

785-
bman->update();
783+
bman->update(ubatch);
786784

787785
// plot the computation graph in dot format (for debugging purposes)
788786
//if (n_past%100 == 0) {

0 commit comments

Comments (0)