
Commit 4aabf4e

return output ID from llama_batch_ext_add/set
1 parent 86973cb commit 4aabf4e

File tree

common/common.h
examples/batched-bench/batched-bench.cpp
examples/batched/batched.cpp
include/llama.h
src/llama-batch.cpp

5 files changed: 31 additions, 25 deletions

common/common.h

Lines changed: 1 addition & 1 deletion
@@ -606,7 +606,7 @@ struct common_batch {
     }
     void set_logits_last() {
         if (!tokens.empty()) {
-            llama_batch_ext_set_logits_last(batch.get());
+            llama_batch_ext_set_output_last(batch.get());
             tokens.back().logits = true;
         }
     }

examples/batched-bench/batched-bench.cpp

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ int main(int argc, char ** argv) {
                 llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
             }
         }
-        llama_batch_ext_set_logits_last(batch);
+        llama_batch_ext_set_output_last(batch);
 
         const auto t_pp_start = ggml_time_us();
 

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    llama_batch_ext_set_logits_last(batch);
+    llama_batch_ext_set_output_last(batch);
 
     if (llama_decode_ext(ctx, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);

include/llama.h

Lines changed: 15 additions & 11 deletions
@@ -900,7 +900,7 @@ extern "C" {
     //
     DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
             llama_token * tokens,
-            int32_t n_tokens), "use llama_batch_ext API instead");
+            int32_t n_tokens), "use llama_batch_ext_init_from_text instead");
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids

@@ -912,7 +912,7 @@ extern "C" {
     DEPRECATED(LLAMA_API struct llama_batch llama_batch_init(
             int32_t n_tokens,
             int32_t embd,
-            int32_t n_seq_max), "use llama_batch_ext API instead");
+            int32_t n_seq_max), "use llama_batch_ext_init instead");
 
     // Frees a batch of tokens allocated with llama_batch_init()
     DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch),

@@ -950,28 +950,32 @@ extern "C" {
 
     // Add text tokens to the batch
     // Return values:
-    //  0 : success
     // -1 : not enough space in the batch
     // -2 : embd is already set, cannot add text tokens
+    // otherwise, returns the output ID
     LLAMA_API int32_t llama_batch_ext_add_text(
             struct llama_batch_ext * batch,
             llama_token token,
             llama_pos pos,
             const llama_seq_id * seq_ids,
             size_t n_seq_ids,
-            float logits);
+            bool output);
 
-    // Set logits for the token in the ith sequence
-    // If pos == -1, logits will be set for the all tokens
-    // Returns -1 if the token is not in the batch
-    LLAMA_API int32_t llama_batch_ext_set_logits(
+    // Set output (logits/embeddings) for the token in the ith sequence
+    // If pos == -1, output will be set for the all tokens
+    // Return values:
+    // -1 : the token is not in the batch
+    // otherwise, returns the output ID
+    LLAMA_API int32_t llama_batch_ext_set_output(
             struct llama_batch_ext * batch,
             llama_pos pos,
             llama_seq_id seq_id);
 
-    // Set logits for the last added token
-    // Returns -1 if there is no tokens in the batch
-    LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch);
+    // Set output (logits/embeddings) for the last added token
+    // Return values:
+    // -1 : the batch is empty
+    // otherwise, returns the output ID
+    LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch);
 
     // Get a "view" from a number of tokens offset
     // Return returned batch must be freed with llama_batch_free()
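The header only documents the return values, so here is a minimal caller-side sketch of how the returned IDs could be used. It is not part of this commit: the helper name decode_prompt is hypothetical, ctx and batch are assumed to be initialized elsewhere, and passing the returned ID to llama_get_logits_ith() after llama_decode_ext() is an assumption about how the output ID is meant to be consumed, not something this diff states.

#include <vector>
#include "llama.h"

// Hypothetical helper: decode a prompt and return the logits pointer for its
// last token, or nullptr on failure.
static const float * decode_prompt(llama_context * ctx, llama_batch_ext * batch,
                                   const std::vector<llama_token> & prompt) {
    const llama_seq_id seq_id = 0;

    int32_t out_id = -1;
    for (size_t i = 0; i < prompt.size(); ++i) {
        // request output only for the last prompt token
        const bool output = (i + 1 == prompt.size());
        out_id = llama_batch_ext_add_text(batch, prompt[i], (llama_pos) i, &seq_id, 1, output);
        if (out_id < 0) {
            return nullptr; // -1: batch full, -2: batch was created for embeddings
        }
    }

    // equivalent alternative: add every token with output=false, then
    // out_id = llama_batch_ext_set_output_last(batch);

    if (out_id < 0 || llama_decode_ext(ctx, batch) != 0) {
        return nullptr;
    }
    return llama_get_logits_ith(ctx, out_id); // assumption: the ID is usable here
}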

src/llama-batch.cpp

Lines changed: 13 additions & 11 deletions
@@ -410,25 +410,26 @@ int32_t llama_batch_ext_add_text(
         llama_pos pos,
         const llama_seq_id * seq_ids,
         size_t n_seq_ids,
-        float logits) {
+        bool output) {
     if (batch->n_tokens + 1 > batch->max_tokens) {
         return -1; // llama_batch size exceeded
     }
     if (batch->embd) {
         return -2; // embd is already set, cannot add text tokens
     }
-    batch->token   [batch->n_tokens] = token;
-    batch->pos     [batch->n_tokens] = pos;
-    batch->n_seq_id[batch->n_tokens] = n_seq_ids;
+    const int32_t output_id = batch->n_tokens;
+    batch->token   [output_id] = token;
+    batch->pos     [output_id] = pos;
+    batch->n_seq_id[output_id] = n_seq_ids;
     for (size_t j = 0; j < n_seq_ids; j++) {
         batch->seq_id[batch->n_tokens][j] = seq_ids[j];
     }
-    batch->logits  [batch->n_tokens] = logits;
+    batch->logits  [output_id] = output;
     batch->n_tokens++;
-    return 0;
+    return output_id;
 }
 
-int32_t llama_batch_ext_set_logits(
+int32_t llama_batch_ext_set_output(
         struct llama_batch_ext * batch,
         llama_pos pos,
         llama_seq_id seq_id) {

@@ -439,20 +440,21 @@ int32_t llama_batch_ext_set_logits(
                 // found the sequence
                 if (pos == -1 || pos == batch->pos[i]) {
                     batch->logits[i] = true;
-                    return 0;
+                    return i;
                 }
             }
         }
     }
     return -1; // not found
 }
 
-int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) {
+int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) {
     if (batch->n_tokens == 0) {
         return -1;
     }
-    batch->logits[batch->n_tokens - 1] = true;
-    return 0;
+    const int32_t output_id = batch->n_tokens - 1;
+    batch->logits[output_id] = true;
+    return output_id;
 }
 
 void llama_batch_ext_clear(struct llama_batch_ext * batch) {
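Because these functions previously returned 0 on success, the implementation above changes the calling convention: only negative values indicate errors, and any non-negative value is an output ID. A small sketch of the per-position variant under that convention; the helper request_output_at is hypothetical and assumes a batch that already contains the relevant tokens.

#include <cstdio>
#include "llama.h"

// Hypothetical snippet, not part of this commit: flag position `pos` of
// sequence `seq_id` for output and keep the returned output ID. Only negative
// return values are errors now; a check like `ret != 0` would wrongly reject
// valid IDs.
static int32_t request_output_at(llama_batch_ext * batch, llama_seq_id seq_id, llama_pos pos) {
    const int32_t out_id = llama_batch_ext_set_output(batch, pos, seq_id);
    if (out_id < 0) {
        fprintf(stderr, "no token at pos %d for seq %d in this batch\n", (int) pos, (int) seq_id);
    }
    return out_id;
}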
