@@ -17149,10 +17149,10 @@ static void llama_graph_compute(
1714917149//
1715017150static int llama_decode_internal(
1715117151 llama_context & lctx,
17152- llama_batch batch_all ) { // TODO: rename back to batch
17152+ llama_batch batch ) {
1715317153
1715417154 lctx.is_encoding = false;
17155- const uint32_t n_tokens_all = batch_all .n_tokens;
17155+ const uint32_t n_tokens_all = batch .n_tokens;
1715617156
1715717157 if (n_tokens_all == 0) {
1715817158 LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -17163,12 +17163,12 @@ static int llama_decode_internal(
1716317163 const auto & hparams = model.hparams;
1716417164 const auto & cparams = lctx.cparams;
1716517165
17166- GGML_ASSERT((!batch_all .token && batch_all .embd) || (batch_all .token && !batch_all .embd)); // NOLINT
17166+ GGML_ASSERT((!batch .token && batch .embd) || (batch .token && !batch .embd)); // NOLINT
1716717167
17168- if (batch_all .token) {
17168+ if (batch .token) {
1716917169 for (uint32_t i = 0; i < n_tokens_all; ++i) {
17170- if (batch_all .token[i] < 0 || (uint32_t)batch_all .token[i] >= model.vocab.n_vocab) {
17171- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_all .token[i]);
17170+ if (batch .token[i] < 0 || (uint32_t)batch .token[i] >= model.vocab.n_vocab) {
17171+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch .token[i]);
1717217172 return -1;
1717317173 }
1717417174 }
@@ -17199,9 +17199,9 @@ static int llama_decode_internal(
1719917199 lctx.embd_seq.clear();
1720017200
1720117201 // count outputs
17202- if (batch_all .logits && !embd_pooled) {
17202+ if (batch .logits && !embd_pooled) {
1720317203 for (uint32_t i = 0; i < n_tokens_all; ++i) {
17204- n_outputs += batch_all .logits[i] != 0;
17204+ n_outputs += batch .logits[i] != 0;
1720517205 }
1720617206 } else if (lctx.logits_all || embd_pooled) {
1720717207 n_outputs = n_tokens_all;
@@ -17210,7 +17210,7 @@ static int llama_decode_internal(
1721017210 n_outputs = 1;
1721117211 }
1721217212
17213- lctx.sbatch.from_batch(batch_all , n_embd,
17213+ lctx.sbatch.from_batch(batch , n_embd,
1721417214 /* simple_split */ !kv_self.recurrent,
1721517215 /* logits_all */ n_outputs == n_tokens_all);
1721617216
0 commit comments