Skip to content

Commit 9fb2d81

Browse files
committed
fix common_batch missing seq_id
1 parent 47086fa commit 9fb2d81

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

common/common.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ struct common_batch {
586586
llama_batch_ext_ptr batch;
587587
struct batch_token {
588588
llama_token token;
589+
llama_seq_id seq_id; // only support single seq for now
589590
bool logits;
590591
};
591592
std::vector<batch_token> tokens;
@@ -601,14 +602,14 @@ struct common_batch {
601602
}
602603
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
603604
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
604-
tokens.push_back({token, logits});
605+
tokens.push_back({token, seq_id, logits});
605606
if (logits) {
606607
n_outputs++;
607608
}
608609
}
609610
void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
610611
llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
611-
tokens.push_back({token, logits});
612+
tokens.push_back({token, seq_ids[0], logits});
612613
if (logits) {
613614
n_outputs++;
614615
}

0 commit comments

Comments
 (0)