@@ -17138,10 +17138,10 @@ static void llama_graph_compute(
1713817138//
1713917139static int llama_decode_internal(
1714017140 llama_context & lctx,
17141- llama_batch batch_all ) { // TODO: rename back to batch
17141+ llama_batch batch ) {
1714217142
1714317143 lctx.is_encoding = false;
17144- const uint32_t n_tokens_all = batch_all .n_tokens;
17144+ const uint32_t n_tokens_all = batch .n_tokens;
1714517145
1714617146 if (n_tokens_all == 0) {
1714717147 LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -17152,12 +17152,12 @@ static int llama_decode_internal(
1715217152 const auto & hparams = model.hparams;
1715317153 const auto & cparams = lctx.cparams;
1715417154
17155- GGML_ASSERT((!batch_all .token && batch_all .embd) || (batch_all .token && !batch_all .embd)); // NOLINT
17155+ GGML_ASSERT((!batch .token && batch .embd) || (batch .token && !batch .embd)); // NOLINT
1715617156
17157- if (batch_all .token) {
17157+ if (batch .token) {
1715817158 for (uint32_t i = 0; i < n_tokens_all; ++i) {
17159- if (batch_all .token[i] < 0 || (uint32_t)batch_all .token[i] >= model.vocab.n_vocab) {
17160- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_all .token[i]);
17159+ if (batch .token[i] < 0 || (uint32_t)batch .token[i] >= model.vocab.n_vocab) {
17160+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch .token[i]);
1716117161 return -1;
1716217162 }
1716317163 }
@@ -17188,9 +17188,9 @@ static int llama_decode_internal(
1718817188 lctx.embd_seq.clear();
1718917189
1719017190 // count outputs
17191- if (batch_all .logits && !embd_pooled) {
17191+ if (batch .logits && !embd_pooled) {
1719217192 for (uint32_t i = 0; i < n_tokens_all; ++i) {
17193- n_outputs += batch_all .logits[i] != 0;
17193+ n_outputs += batch .logits[i] != 0;
1719417194 }
1719517195 } else if (lctx.logits_all || embd_pooled) {
1719617196 n_outputs = n_tokens_all;
@@ -17199,7 +17199,7 @@ static int llama_decode_internal(
1719917199 n_outputs = 1;
1720017200 }
1720117201
17202- lctx.sbatch.from_batch(batch_all , n_embd,
17202+ lctx.sbatch.from_batch(batch , n_embd,
1720317203 /* simple_split */ !kv_self.recurrent,
1720417204 /* logits_all */ n_outputs == n_tokens_all);
1720517205
0 commit comments