@@ -14447,10 +14447,10 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   batch) {
 
     lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch_all.n_tokens;
+    const uint32_t n_tokens_all = batch.n_tokens;
 
     if (n_tokens_all == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
@@ -14461,7 +14461,7 @@ static int llama_decode_internal(
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
@@ -14492,9 +14492,9 @@ static int llama_decode_internal(
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
     // count outputs
-    if (batch_all.logits && !embd_pooled) {
+    if (batch.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch_all.logits[i] != 0;
+            n_outputs += batch.logits[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
@@ -14510,10 +14510,10 @@ static int llama_decode_internal(
     };
 
     // set output mappings
-    if (batch_all.logits) {
+    if (batch.logits) {
         int32_t i_logits = 0;
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.logits[i]) {
+            if (batch.logits[i]) {
                 lctx.output_ids[i] = i_logits++;
             }
         }
@@ -14527,15 +14527,15 @@ static int llama_decode_internal(
         const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
         llama_batch u_batch = {
             /* .n_tokens   = */ (int32_t) n_tokens,
-            /* .token      = */ batch_all.token    ? batch_all.token    + cur_token        : nullptr,
-            /* .embd       = */ batch_all.embd     ? batch_all.embd     + cur_token*n_embd : nullptr,
-            /* .pos        = */ batch_all.pos      ? batch_all.pos      + cur_token        : nullptr,
-            /* .n_seq_id   = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token        : nullptr,
-            /* .seq_id     = */ batch_all.seq_id   ? batch_all.seq_id   + cur_token        : nullptr,
-            /* .logits     = */ batch_all.logits   ? batch_all.logits   + cur_token        : nullptr,
-            /* .all_pos_0  = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
-            /* .all_pos_1  = */ batch_all.all_pos_1,
-            /* .all_seq_id = */ batch_all.all_seq_id,
+            /* .token      = */ batch.token    ? batch.token    + cur_token        : nullptr,
+            /* .embd       = */ batch.embd     ? batch.embd     + cur_token*n_embd : nullptr,
+            /* .pos        = */ batch.pos      ? batch.pos      + cur_token        : nullptr,
+            /* .n_seq_id   = */ batch.n_seq_id ? batch.n_seq_id + cur_token        : nullptr,
+            /* .seq_id     = */ batch.seq_id   ? batch.seq_id   + cur_token        : nullptr,
+            /* .logits     = */ batch.logits   ? batch.logits   + cur_token        : nullptr,
+            /* .all_pos_0  = */ batch.all_pos_0 + (llama_pos) cur_token*batch.all_pos_1,
+            /* .all_pos_1  = */ batch.all_pos_1,
+            /* .all_seq_id = */ batch.all_seq_id,
         };
 
         // count the outputs in this u_batch
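The last hunk above is the part worth pausing on: the full batch is split into micro-batches of at most n_ubatch tokens purely by offsetting each per-token array pointer by cur_token (with a cur_token*n_embd stride for the embeddings array), leaving null arrays null. Below is a minimal, self-contained sketch of that slicing pattern; batch_view and slice_batch are hypothetical names for illustration, not part of the llama.cpp API, and the field set is a reduced version of the llama_batch layout shown in the diff.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Reduced stand-in for llama_batch: per-token arrays, any of which may be null.
    struct batch_view {
        int32_t        n_tokens;
        const int32_t *token;  // token ids, or nullptr when embeddings are passed instead
        const float   *embd;   // embeddings, n_embd floats per token
        const int32_t *pos;    // per-token positions
        const int8_t  *logits; // per-token output flags
    };

    // Slice [cur_token, cur_token + n_tokens) out of the full batch, following the
    // pointer-offset pattern from llama_decode_internal: null arrays stay null,
    // and embd advances by cur_token*n_embd because it is a flat 2D array.
    static batch_view slice_batch(const batch_view & all, int32_t cur_token,
                                  int32_t n_tokens, int32_t n_embd) {
        return batch_view {
            /* .n_tokens = */ n_tokens,
            /* .token    = */ all.token  ? all.token  + cur_token        : nullptr,
            /* .embd     = */ all.embd   ? all.embd   + cur_token*n_embd : nullptr,
            /* .pos      = */ all.pos    ? all.pos    + cur_token        : nullptr,
            /* .logits   = */ all.logits ? all.logits + cur_token        : nullptr,
        };
    }

    int main() {
        const int32_t tokens[] = {11, 22, 33, 44, 55, 66, 77};
        const int32_t pos[]    = { 0,  1,  2,  3,  4,  5,  6};
        const batch_view all   = {7, tokens, nullptr, pos, nullptr};

        // Same loop shape as the decode path: walk the batch in n_ubatch steps,
        // clamping the final slice to the tokens that remain.
        const int32_t n_ubatch = 3;
        for (int32_t cur = 0; cur < all.n_tokens; cur += n_ubatch) {
            const int32_t n = std::min(n_ubatch, all.n_tokens - cur);
            const batch_view u = slice_batch(all, cur, n, /*n_embd=*/0);
            std::printf("u_batch of %d tokens, first id %d at pos %d\n",
                        u.n_tokens, u.token[0], u.pos[0]);
        }
        return 0;
    }

Because the slices are just offset pointers into the caller's arrays, no token data is copied; each micro-batch is a view, which is why the diff only adjusts pointers and the all_pos_0 base position rather than allocating anything per micro-batch.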