Commit 145401c

context : fix logits size overflow for huge batches
1 parent: f16a843

File tree

1 file changed (+2 −2 lines)

src/llama-context.cpp

Lines changed: 2 additions & 2 deletions
@@ -786,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;

     const int64_t n_embd = hparams.n_embd;
-    const int32_t n_vocab = model.vocab.n_tokens();
+    const int64_t n_vocab = model.vocab.n_tokens();

     // note: during encode, we always pass the full sequence starting from pos = 0
     if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
@@ -959,7 +959,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto & vocab = model.vocab;
     const auto & hparams = model.hparams;

-    const int32_t n_vocab = vocab.n_tokens();
+    const int64_t n_vocab = vocab.n_tokens();
     const int64_t n_embd = hparams.n_embd;

     // when computing embeddings, all tokens are output
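
The fix widens n_vocab from int32_t to int64_t in both encode() and decode(). Per the commit message, the motivation is that logits buffer sizes and offsets are computed as products involving n_vocab; with a huge batch, a product of two 32-bit values can exceed INT32_MAX before it is ever widened, so the resulting size is wrong. The following is a minimal, self-contained sketch of that failure mode and of why widening one operand is enough; the numbers and the logits_bytes_* helpers are hypothetical, not taken from llama.cpp.

// Sketch only: illustrates the 32-bit overflow the commit guards against.
// The values and the buffer-size helpers are hypothetical, not llama.cpp code.
#include <cstdint>
#include <cstdio>

// Hypothetical helper: bytes needed to store logits for n_outputs tokens.
static size_t logits_bytes_32(int32_t n_vocab, int32_t n_outputs) {
    // Both operands are 32-bit, so the multiplication is done in 32-bit
    // arithmetic and can overflow (undefined behaviour in C++, typically a
    // wrap-around in practice) before the conversion to size_t.
    return sizeof(float) * size_t(n_vocab * n_outputs);
}

static size_t logits_bytes_64(int64_t n_vocab, int32_t n_outputs) {
    // n_vocab is 64-bit, so n_outputs is promoted and the product is exact.
    return sizeof(float) * size_t(n_vocab * n_outputs);
}

int main() {
    const int32_t n_vocab   = 150000; // hypothetical large vocabulary
    const int32_t n_outputs = 20000;  // hypothetical huge batch:
                                      // 150000 * 20000 = 3e9 > INT32_MAX

    printf("32-bit n_vocab: %zu bytes (wrong)\n",   logits_bytes_32(n_vocab, n_outputs));
    printf("64-bit n_vocab: %zu bytes (correct)\n", logits_bytes_64(n_vocab, n_outputs));
    return 0;
}

Widening only n_vocab is sufficient because C++'s usual arithmetic conversions promote the other 32-bit operand to 64 bits wherever the two are multiplied, so every size or offset derived from n_vocab stays in range.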
