
Commit 4ed4fe7

first proposal for private llama_batch
1 parent 04045bb commit 4ed4fe7

3 files changed, +171 −49 lines changed

include/llama.h

Lines changed: 51 additions & 27 deletions

@@ -231,29 +231,7 @@ extern "C" {

     typedef bool (*llama_progress_callback)(float progress, void * user_data);

-    // Input data for llama_decode
-    // A llama_batch object can contain input about one or many sequences
-    // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
-    //
-    // - token  : the token ids of the input (used when embd is NULL)
-    // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
-    // - pos    : the positions of the respective token in the sequence
-    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
-    // - seq_id : the sequence to which the respective token belongs
-    //            (if set to NULL, the sequence ID will be assumed to be 0)
-    // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
-    //            (if set to NULL, only the logits for last token will be returned)
-    //
-    typedef struct llama_batch {
-        int32_t n_tokens;
-
-        llama_token  *  token;
-        float        *  embd;
-        llama_pos    *  pos;
-        int32_t      *  n_seq_id;
-        llama_seq_id ** seq_id;
-        int8_t       *  logits; // TODO: rename this to "output"
-    } llama_batch;
+    struct llama_batch;

     enum llama_model_kv_override_type {
         LLAMA_KV_OVERRIDE_TYPE_INT,

@@ -829,7 +807,7 @@ extern "C" {
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
-    LLAMA_API struct llama_batch llama_batch_get_one(
+    LLAMA_API struct llama_batch * llama_batch_get_one(
                   llama_token * tokens,
                       int32_t   n_tokens);

@@ -840,13 +818,59 @@ extern "C" {
     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
     // The rest of the llama_batch members are allocated with size n_tokens
     // All members are left uninitialized
-    LLAMA_API struct llama_batch llama_batch_init(
+    // LLAMA_API struct llama_batch llama_batch_init(
+    //         int32_t n_tokens,
+    //         int32_t embd,
+    //         int32_t n_seq_max);
+
+    // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
+    // Each token can be assigned up to n_seq_max sequence ids
+    // The batch has to be freed with llama_batch_free()
+    LLAMA_API struct llama_batch * llama_batch_init(
             int32_t n_tokens,
-            int32_t embd,
             int32_t n_seq_max);

+    // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
+    LLAMA_API struct llama_batch * llama_batch_init_from_embd(
+              float * embd,
+             size_t   n_embd,
+            int32_t   pos0,
+            int32_t   seq_id);
+
+    // Add text tokens to the batch
+    // First token in the list starts at position pos0
+    // Return values:
+    //    0 : success
+    //   -1 : not enough space in the batch
+    //   -2 : embd is already set, cannot add text tokens
+    LLAMA_API int32_t llama_batch_add_text(
+            struct llama_batch * batch,
+                   llama_token * tokens,
+                        size_t   n_tokens,
+                       int32_t   pos0,
+                       int32_t   seq_id);
+
+    // Same as llama_batch_add_text, but accepts multiple sequences
+    LLAMA_API int32_t llama_batch_add_text(
+            struct llama_batch * batch,
+                   llama_token * tokens,
+                        size_t   n_tokens,
+                       int32_t   pos0,
+                       int32_t * seq_ids,
+                        size_t   n_seq_ids);
+
+    // Set logits for the token in the ith sequence
+    // If pos == -1, logits will be set for the all tokens
+    LLAMA_API int32_t llama_batch_set_logits(
+            struct llama_batch * batch,
+            int32_t pos,
+            int32_t seq_id);
+
+    // Remove everything from the batch
+    LLAMA_API void llama_batch_clear(struct llama_batch * batch);
+
     // Frees a batch of tokens allocated with llama_batch_init()
-    LLAMA_API void llama_batch_free(struct llama_batch batch);
+    LLAMA_API void llama_batch_free(struct llama_batch * batch);

     // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
     // Stores the encoder output internally for later use by the decoder cross-attention layers.
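For orientation, here is a minimal caller-side sketch of the proposed API (not part of the commit; it assumes only the declarations above, and the llama_decode integration is left as a comment since this diff does not touch it):

    #include "llama.h"

    #include <vector>

    int build_batch_example(std::vector<llama_token> & prompt) {
        // heap-allocated batch: capacity of 512 tokens, 1 sequence id per token
        llama_batch * batch = llama_batch_init(/*n_tokens =*/ 512, /*n_seq_max =*/ 1);

        // append the prompt to sequence 0, first token at position 0;
        // returns -1 if the batch is full, -2 if the batch holds embeddings
        if (llama_batch_add_text(batch, prompt.data(), prompt.size(),
                                 /*pos0 =*/ 0, /*seq_id =*/ 0) != 0) {
            llama_batch_free(batch);
            return 1;
        }

        // request output only for the last prompt token of sequence 0
        llama_batch_set_logits(batch, /*pos =*/ (int32_t) prompt.size() - 1, /*seq_id =*/ 0);

        // ... llama_decode(...) would consume the batch here ...

        llama_batch_clear(batch); // reset n_tokens for reuse; allocations are kept
        llama_batch_free(batch);  // release the heap storage and the batch itself
        return 0;
    }

With the struct opaque, every mutation goes through a function call, which is what allows the internal layout to change later without breaking the public ABI.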

src/llama-batch.cpp

Lines changed: 96 additions & 22 deletions

@@ -309,10 +309,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
 // interface implementation
 //

-struct llama_batch llama_batch_get_one(
+struct llama_batch * llama_batch_get_one(
              llama_token * tokens,
                  int32_t   n_tokens) {
-    return {
+    return new llama_batch{
         /*n_tokens =*/ n_tokens,
         /*tokens   =*/ tokens,
         /*embd     =*/ nullptr,
@@ -323,8 +323,8 @@ struct llama_batch llama_batch_get_one(
     };
 }

-struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
-    llama_batch batch = {
+static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
+    llama_batch * batch = new llama_batch{
         /*n_tokens =*/ 0,
         /*tokens   =*/ nullptr,
         /*embd     =*/ nullptr,
@@ -335,34 +335,108 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
     };

     if (embd) {
-        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+        batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
     } else {
-        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
+        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }

-    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
-    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
-    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
+    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
+    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
+    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
     for (int i = 0; i < n_tokens_alloc; ++i) {
-        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch.seq_id[n_tokens_alloc] = nullptr;
+    batch->seq_id[n_tokens_alloc] = nullptr;

-    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
+    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);

     return batch;
 }

-void llama_batch_free(struct llama_batch batch) {
-    if (batch.token)    free(batch.token);
-    if (batch.embd)     free(batch.embd);
-    if (batch.pos)      free(batch.pos);
-    if (batch.n_seq_id) free(batch.n_seq_id);
-    if (batch.seq_id) {
-        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
-            free(batch.seq_id[i]);
+struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
+    return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max);
+}
+
+struct llama_batch * llama_batch_init_from_embd(
+             float * embd,
+            size_t   n_embd,
+           int32_t   pos0,
+           int32_t   seq_id) {
+    struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1);
+    memcpy(batch->embd, embd, n_embd * sizeof(float));
+    for (int32_t i = 0; i < n_embd; i++) {
+        batch->pos     [i]    = pos0 + i;
+        batch->n_seq_id[i]    = 1;
+        batch->seq_id  [i][0] = seq_id;
+    }
+}
+
+int32_t llama_batch_add_text(
+        struct llama_batch * batch,
+               llama_token * tokens,
+                    size_t   n_tokens,
+                   int32_t   pos0,
+                   int32_t * seq_ids,
+                    size_t   n_seq_ids) {
+    if (batch->n_tokens + n_tokens > batch->n_tokens) {
+        return -1;
+    }
+    if (batch->embd) {
+        return -2;
+    }
+    for (int32_t i = 0; i < n_tokens; i++) {
+        batch->token   [batch->n_tokens + i] = tokens[i];
+        batch->pos     [batch->n_tokens + i] = pos0 + i;
+        batch->n_seq_id[batch->n_tokens + i] = n_seq_ids;
+        for (int32_t j = 0; j < n_seq_ids; j++) {
+            batch->seq_id[batch->n_tokens + i][j] = seq_ids[j];
+        }
+    }
+}
+
+int32_t llama_batch_add_text(
+        struct llama_batch * batch,
+               llama_token * tokens,
+                    size_t   n_tokens,
+                   int32_t   pos0,
+                   int32_t   seq_id) {
+    std::array<int32_t, 1> seq_ids = { seq_id };
+    return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size());
+}
+
+int32_t llama_batch_set_logits(
+        struct llama_batch * batch,
+        int32_t pos,
+        int32_t seq_id) {
+    for (int32_t i = 0; i < batch->n_tokens; i++) {
+        // find the token having seq_id
+        for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
+            if (batch->seq_id[i][j] == seq_id) {
+                // found the sequence
+                if (pos == -1 || pos == batch->pos[i]) {
+                    batch->logits[i] = true;
+                    break;
+                }
+            }
+        }
+    }
+}
+
+void llama_batch_clear(struct llama_batch * batch) {
+    batch->n_tokens = 0;
+}
+
+void llama_batch_free(struct llama_batch * batch) {
+    if (batch->token)    free(batch->token);
+    if (batch->embd)     free(batch->embd);
+    if (batch->pos)      free(batch->pos);
+    if (batch->n_seq_id) free(batch->n_seq_id);
+    if (batch->seq_id) {
+        for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
+            free(batch->seq_id[i]);
         }
-        free(batch.seq_id);
+        free(batch->seq_id);
     }
-    if (batch.logits) free(batch.logits);
+    if (batch->logits)  free(batch->logits);
+    delete batch;
 }
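One detail worth calling out in llama_batch_init_impl: seq_id is allocated with n_tokens_alloc + 1 entries and the extra slot is set to nullptr, so llama_batch_free can walk the per-token arrays without storing a capacity anywhere. A standalone sketch of that sentinel pattern (hypothetical helper names, same idea):

    #include <cstdlib>

    // allocate n_rows arrays plus one sentinel slot, mirroring batch->seq_id
    int ** alloc_rows(int n_rows, int row_len) {
        int ** rows = (int **) malloc(sizeof(int *) * (n_rows + 1));
        for (int i = 0; i < n_rows; ++i) {
            rows[i] = (int *) malloc(sizeof(int) * row_len);
        }
        rows[n_rows] = nullptr; // sentinel marks the end
        return rows;
    }

    // free without knowing the row count: stop at the sentinel
    void free_rows(int ** rows) {
        for (int i = 0; rows[i] != nullptr; ++i) {
            free(rows[i]);
        }
        free(rows);
    }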

src/llama-batch.h

Lines changed: 24 additions & 0 deletions

@@ -5,6 +5,30 @@
 #include <array>
 #include <vector>

+// Input data for llama_decode
+// A llama_batch object can contain input about one or many sequences
+// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
+//
+// - token  : the token ids of the input (used when embd is NULL)
+// - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
+// - pos    : the positions of the respective token in the sequence
+//            (if set to NULL, the token position will be tracked automatically by llama_decode)
+// - seq_id : the sequence to which the respective token belongs
+//            (if set to NULL, the sequence ID will be assumed to be 0)
+// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+//            (if set to NULL, only the logits for last token will be returned)
+//
+struct llama_batch {
+    int32_t n_tokens;
+
+    llama_token  *  token;
+    float        *  embd;
+    llama_pos    *  pos;
+    int32_t      *  n_seq_id;
+    llama_seq_id ** seq_id;
+    int8_t       *  logits; // TODO: rename this to "output"
+};
+
 // very similar to llama_batch,
 // but has more metadata about sequences
 struct llama_ubatch {
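Since the struct definition now lives in this private header, code outside the library can no longer fill llama_batch fields directly. A hypothetical caller-side migration sketch (the helper name is invented for illustration):

    #include "llama.h"

    // Before this change, a caller could write the public fields by hand:
    //     batch.token   [batch.n_tokens]    = tok;
    //     batch.pos     [batch.n_tokens]    = pos;
    //     batch.n_seq_id[batch.n_tokens]    = 1;
    //     batch.seq_id  [batch.n_tokens][0] = 0;
    //     batch.n_tokens++;
    // With llama_batch opaque, the same intent goes through the new API:
    static void append_one_token(llama_batch * batch, llama_token tok, llama_pos pos) {
        llama_batch_add_text(batch, &tok, /*n_tokens =*/ 1, /*pos0 =*/ pos, /*seq_id =*/ 0);
    }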
