@@ -17134,10 +17134,10 @@ static void llama_graph_compute(
1713417134//
1713517135static int llama_decode_internal(
1713617136 llama_context & lctx,
17137- llama_batch batch_all ) { // TODO: rename back to batch
17137+ llama_batch batch ) {
1713817138
1713917139 lctx.is_encoding = false;
17140- const uint32_t n_tokens_all = batch_all .n_tokens;
17140+ const uint32_t n_tokens_all = batch .n_tokens;
1714117141
1714217142 if (n_tokens_all == 0) {
1714317143 LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -17148,12 +17148,12 @@ static int llama_decode_internal(
1714817148 const auto & hparams = model.hparams;
1714917149 const auto & cparams = lctx.cparams;
1715017150
17151- GGML_ASSERT((!batch_all .token && batch_all .embd) || (batch_all .token && !batch_all .embd)); // NOLINT
17151+ GGML_ASSERT((!batch .token && batch .embd) || (batch .token && !batch .embd)); // NOLINT
1715217152
17153- if (batch_all .token) {
17153+ if (batch .token) {
1715417154 for (uint32_t i = 0; i < n_tokens_all; ++i) {
17155- if (batch_all .token[i] < 0 || (uint32_t)batch_all .token[i] >= model.vocab.n_vocab) {
17156- LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_all .token[i]);
17155+ if (batch .token[i] < 0 || (uint32_t)batch .token[i] >= model.vocab.n_vocab) {
17156+ LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch .token[i]);
1715717157 return -1;
1715817158 }
1715917159 }
@@ -17184,9 +17184,9 @@ static int llama_decode_internal(
1718417184 lctx.embd_seq.clear();
1718517185
1718617186 // count outputs
17187- if (batch_all .logits && !embd_pooled) {
17187+ if (batch .logits && !embd_pooled) {
1718817188 for (uint32_t i = 0; i < n_tokens_all; ++i) {
17189- n_outputs += batch_all .logits[i] != 0;
17189+ n_outputs += batch .logits[i] != 0;
1719017190 }
1719117191 } else if (lctx.logits_all || embd_pooled) {
1719217192 n_outputs = n_tokens_all;
@@ -17195,7 +17195,7 @@ static int llama_decode_internal(
1719517195 n_outputs = 1;
1719617196 }
1719717197
17198- lctx.sbatch.from_batch(batch_all , n_embd,
17198+ lctx.sbatch.from_batch(batch , n_embd,
1719917199 /* simple_split */ !kv_self.recurrent,
1720017200 /* logits_all */ n_outputs == n_tokens_all);
1720117201
0 commit comments