
Commit 436872f

llama : rename batch_all to batch
This commit addresses the TODO in the code to rename the `batch_all` parameter to `batch` in `llama_decode_internal`.
1 parent c21a896 commit 436872f
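The rename is internal only; the public entry point already used the shorter name. As a rough sketch of the call path (abridged and paraphrased from llama.cpp of this vintage, not a verbatim copy), the C API function `llama_decode` forwards its `batch` argument straight into the renamed internal function:

// Abridged sketch of the public C API entry point. It already takes a
// parameter named `batch`, so the rename brings llama_decode_internal
// in line with it. Error handling is paraphrased, not copied verbatim.
int32_t llama_decode(
        struct llama_context * ctx,
          struct llama_batch   batch) {
    const int ret = llama_decode_internal(*ctx, batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }

    return ret;
}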

File tree: 1 file changed (+16 −16 lines)

src/llama.cpp

Lines changed: 16 additions & 16 deletions
@@ -14447,10 +14447,10 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-          llama_batch   batch_all) { // TODO: rename back to batch
+          llama_batch   batch) {
 
     lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch_all.n_tokens;
+    const uint32_t n_tokens_all = batch.n_tokens;
 
     if (n_tokens_all == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
@@ -14461,7 +14461,7 @@ static int llama_decode_internal(
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
@@ -14492,9 +14492,9 @@
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
     // count outputs
-    if (batch_all.logits && !embd_pooled) {
+    if (batch.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch_all.logits[i] != 0;
+            n_outputs += batch.logits[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
@@ -14510,10 +14510,10 @@
     };
 
     // set output mappings
-    if (batch_all.logits) {
+    if (batch.logits) {
         int32_t i_logits = 0;
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.logits[i]) {
+            if (batch.logits[i]) {
                 lctx.output_ids[i] = i_logits++;
             }
         }
@@ -14527,15 +14527,15 @@
         const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
         llama_batch u_batch = {
             /* .n_tokens   = */ (int32_t) n_tokens,
-            /* .token      = */ batch_all.token    ? batch_all.token    + cur_token        : nullptr,
-            /* .embd       = */ batch_all.embd     ? batch_all.embd     + cur_token*n_embd : nullptr,
-            /* .pos        = */ batch_all.pos      ? batch_all.pos      + cur_token        : nullptr,
-            /* .n_seq_id   = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token        : nullptr,
-            /* .seq_id     = */ batch_all.seq_id   ? batch_all.seq_id   + cur_token        : nullptr,
-            /* .logits     = */ batch_all.logits   ? batch_all.logits   + cur_token        : nullptr,
-            /* .all_pos_0  = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
-            /* .all_pos_1  = */ batch_all.all_pos_1,
-            /* .all_seq_id = */ batch_all.all_seq_id,
+            /* .token      = */ batch.token    ? batch.token    + cur_token        : nullptr,
+            /* .embd       = */ batch.embd     ? batch.embd     + cur_token*n_embd : nullptr,
+            /* .pos        = */ batch.pos      ? batch.pos      + cur_token        : nullptr,
+            /* .n_seq_id   = */ batch.n_seq_id ? batch.n_seq_id + cur_token        : nullptr,
+            /* .seq_id     = */ batch.seq_id   ? batch.seq_id   + cur_token        : nullptr,
+            /* .logits     = */ batch.logits   ? batch.logits   + cur_token        : nullptr,
+            /* .all_pos_0  = */ batch.all_pos_0 + (llama_pos) cur_token*batch.all_pos_1,
+            /* .all_pos_1  = */ batch.all_pos_1,
+            /* .all_seq_id = */ batch.all_seq_id,
         };
 
         // count the outputs in this u_batch
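The last hunk sits in the micro-batching loop, where the caller-supplied batch is sliced into `n_ubatch`-sized views by pointer arithmetic rather than by copying. A self-contained sketch of that slicing pattern (with a hypothetical, cut-down struct standing in for the real `llama_batch`):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for llama_batch: only the fields needed to show
// the slicing pattern from llama_decode_internal.
struct toy_batch {
    int32_t        n_tokens;
    const int32_t *token;  // one id per token, may be null
    const int8_t  *logits; // per-token output flag, may be null
};

int main() {
    const int32_t tokens[] = {10, 11, 12, 13, 14, 15, 16};
    const int8_t  flags [] = { 0,  0,  0,  0,  0,  0,  1};
    const toy_batch batch  = { 7, tokens, flags };

    const uint32_t n_ubatch = 3; // micro-batch size (cparams.n_ubatch in llama.cpp)

    // Walk the full batch in n_ubatch-sized windows; each u_batch is a
    // view into the parent arrays, offset by cur_token -- no copies.
    for (uint32_t cur_token = 0; cur_token < (uint32_t) batch.n_tokens; ) {
        const uint32_t n_tokens = std::min(n_ubatch, (uint32_t) batch.n_tokens - cur_token);
        const toy_batch u_batch = {
            /* .n_tokens = */ (int32_t) n_tokens,
            /* .token    = */ batch.token  ? batch.token  + cur_token : nullptr,
            /* .logits   = */ batch.logits ? batch.logits + cur_token : nullptr,
        };
        printf("u_batch of %d tokens starting at offset %u\n",
               (int) u_batch.n_tokens, cur_token);
        cur_token += n_tokens;
    }
    return 0;
}

With the inputs above this prints three windows (3, 3, and 1 tokens), mirroring how the real loop feeds successive `u_batch` views into graph computation.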
