
Commit f16a843

context : fix overflow when re-ordering huge outputs
1 parent ec428b0 commit f16a843

1 file changed (+6 −6)

src/llama-context.cpp

@@ -1328,21 +1328,21 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
 }
 
 void llama_context::output_reorder() {
-    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_vocab = model.vocab.n_tokens();
     const uint64_t n_embd = model.hparams.n_embd;
 
-    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
-        const uint32_t i0 = output_swaps[s].i0;
-        const uint32_t i1 = output_swaps[s].i1;
+    for (size_t s = 0; s < output_swaps.size(); ++s) {
+        const uint64_t i0 = output_swaps[s].i0;
+        const uint64_t i1 = output_swaps[s].i1;
 
         if (logits_size > 0) {
-            for (uint32_t k = 0; k < n_vocab; k++) {
+            for (uint64_t k = 0; k < n_vocab; k++) {
                 std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
             }
         }
 
         if (embd_size > 0) {
-            for (uint32_t k = 0; k < n_embd; k++) {
+            for (uint64_t k = 0; k < n_embd; k++) {
                 std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
             }
         }
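The bug being fixed is a 32-bit index overflow: with `uint32_t` operands, the flat index `i0*n_vocab + k` is computed modulo 2^32 before it is widened for the array subscript, so once a row's element offset exceeds 2^32 the swap reads and writes the wrong rows. Below is a minimal sketch of the failure mode; the concrete values are illustrative, not taken from the commit:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative values (hypothetical, not from the commit): a vocab of
    // ~128k tokens and an output row far enough into the buffer that the
    // element offset i0*n_vocab exceeds 2^32.
    const uint32_t n_vocab_32 = 128256;
    const uint32_t i0_32      = 40000;

    // Old behavior: both operands are 32-bit, so the product wraps
    // modulo 2^32 *before* being widened for the subscript.
    const uint64_t wrapped = uint64_t(i0_32 * n_vocab_32);

    // Fixed behavior: widen an operand first (the commit declares the
    // variables as uint64_t), so the product is computed in 64 bits.
    const uint64_t correct = uint64_t(i0_32) * n_vocab_32;

    printf("32-bit index math: %llu\n", (unsigned long long)wrapped); // 835272704 (wrong)
    printf("64-bit index math: %llu\n", (unsigned long long)correct); // 5130240000
    return 0;
}
```

Declaring `n_vocab`, `i0`, and `i1` as `uint64_t` (matching the existing `n_embd`) promotes the whole index expression to 64-bit arithmetic, which is what the diff above does.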
