@@ -786,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
786
786
const auto & hparams = model.hparams ;
787
787
788
788
const int64_t n_embd = hparams.n_embd ;
789
- const int32_t n_vocab = model.vocab .n_tokens ();
789
+ const int64_t n_vocab = model.vocab .n_tokens ();
790
790
791
791
// note: during encode, we always pass the full sequence starting from pos = 0
792
792
if (!balloc->init (batch_inp, model.vocab , nullptr , n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max , true )) {
@@ -959,7 +959,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
959
959
const auto & vocab = model.vocab ;
960
960
const auto & hparams = model.hparams ;
961
961
962
- const int32_t n_vocab = vocab.n_tokens ();
962
+ const int64_t n_vocab = vocab.n_tokens ();
963
963
const int64_t n_embd = hparams.n_embd ;
964
964
965
965
// when computing embeddings, all tokens are output
@@ -1328,21 +1328,21 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
1328
1328
}
1329
1329
1330
1330
void llama_context::output_reorder () {
1331
- const uint32_t n_vocab = model.vocab .n_tokens ();
1331
+ const uint64_t n_vocab = model.vocab .n_tokens ();
1332
1332
const uint64_t n_embd = model.hparams .n_embd ;
1333
1333
1334
- for (uint32_t s = 0 ; s < output_swaps.size (); ++s) {
1335
- const uint32_t i0 = output_swaps[s].i0 ;
1336
- const uint32_t i1 = output_swaps[s].i1 ;
1334
+ for (size_t s = 0 ; s < output_swaps.size (); ++s) {
1335
+ const uint64_t i0 = output_swaps[s].i0 ;
1336
+ const uint64_t i1 = output_swaps[s].i1 ;
1337
1337
1338
1338
if (logits_size > 0 ) {
1339
- for (uint32_t k = 0 ; k < n_vocab; k++) {
1339
+ for (uint64_t k = 0 ; k < n_vocab; k++) {
1340
1340
std::swap (logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
1341
1341
}
1342
1342
}
1343
1343
1344
1344
if (embd_size > 0 ) {
1345
- for (uint32_t k = 0 ; k < n_embd; k++) {
1345
+ for (uint64_t k = 0 ; k < n_embd; k++) {
1346
1346
std::swap (embd[i0*n_embd + k], embd[i1*n_embd + k]);
1347
1347
}
1348
1348
}
0 commit comments