
Commit 17d3658

move to llama_batch_ext

1 parent f2e59a8

File tree: 8 files changed, +223 -118 lines

common/common.cpp

Lines changed: 15 additions & 6 deletions
@@ -1610,20 +1610,29 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
 // Batch utils
 //
 
-void common_batch_clear(struct llama_batch * batch) {
-    llama_batch_clear(batch);
+// DEPRECATED
+void common_batch_clear(struct llama_batch & batch) {
+    batch.n_tokens = 0;
 }
 
+// DEPRECATED
 void common_batch_add(
-        struct llama_batch * batch,
+        struct llama_batch & batch,
         llama_token id,
         llama_pos pos,
         const std::vector<llama_seq_id> & seq_ids,
         bool logits) {
-    int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits);
-    if (res == -1) {
-        LOG_ERR("%s: llama_batch size exceeded\n", __func__);
+    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
+
+    batch.token   [batch.n_tokens] = id;
+    batch.pos     [batch.n_tokens] = pos;
+    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); ++i) {
+        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
     }
+    batch.logits  [batch.n_tokens] = logits;
+
+    batch.n_tokens++;
 }
 
 //
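Note: the GGML_ASSERT overflow check above reads batch.seq_id[batch.n_tokens], which only works because llama_batch_init allocates the seq_id array with one extra slot and NULL-terminates it, so the entry past the last usable token is NULL exactly when the batch is full. That invariant is not visible in this hunk; a sketch of it, assuming the upstream allocator:

    // sketch, not part of this diff -- the invariant common_batch_add relies on,
    // inside llama_batch_init(n_tokens_alloc, embd, n_seq_max):
    //   batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
    //   for (int i = 0; i < n_tokens_alloc; ++i) { batch.seq_id[i] = /* per-token array */; }
    //   batch.seq_id[n_tokens_alloc] = nullptr; // sentinel: the assert fires once capacity is exhausted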

common/common.h

Lines changed: 4 additions & 2 deletions
@@ -554,10 +554,12 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
 // Batch utils
 //
 
-void common_batch_clear(struct llama_batch * batch);
+// DEPRECATED
+void common_batch_clear(struct llama_batch & batch);
 
+// DEPRECATED
 void common_batch_add(
-        struct llama_batch * batch,
+        struct llama_batch & batch,
         llama_token id,
         llama_pos pos,
         const std::vector<llama_seq_id> & seq_ids,
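For reference, a minimal caller-side sketch of the deprecated helpers with their new by-reference signatures; ctx and tok are placeholders, and the 3-argument llama_batch_init is the deprecated form declared in include/llama.h below:

    llama_batch batch = llama_batch_init(/*n_tokens=*/512, /*embd=*/0, /*n_seq_max=*/1);
    common_batch_clear(batch);
    common_batch_add(batch, tok, /*pos=*/0, /*seq_ids=*/{ 0 }, /*logits=*/true);
    llama_decode(ctx, batch);  // deprecated overload taking llama_batch by value
    llama_batch_free(batch);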

common/speculative.cpp

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@ struct common_speculative {
     struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_batch * batch;
+    llama_batch batch;
     llama_tokens prompt;
 };
 
@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 1),
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
         /* .prompt = */ {},
     };
 
@@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft(
     }
 
     // we should rarely end-up here during normal decoding
-    if (llama_batch_get_n_tokens(batch) > 0) {
+    if (batch.n_tokens > 0) {
         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
         llama_decode(ctx, batch);
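The draft batch is now held by value instead of through a pointer, so the llama_batch_get_n_tokens() accessor is replaced by a direct n_tokens read. A minimal sketch of the resulting pattern, assuming cleanup happens in common_speculative_free (that function is not shown in these hunks):

    struct common_speculative * spec = common_speculative_init(ctx_dft);
    // ... fill spec->batch, then test spec->batch.n_tokens directly ...
    common_speculative_free(spec);  // presumably calls llama_batch_free(spec->batch) internally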

include/llama-cpp.h

Lines changed: 3 additions & 3 deletions
@@ -24,12 +24,12 @@ struct llama_adapter_lora_deleter {
     void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
 };
 
-struct llama_batch_deleter {
-    void operator()(llama_batch * batch) { llama_batch_free(batch); }
+struct llama_batch_ext_deleter {
+    void operator()(llama_batch_ext * batch) { llama_batch_ext_free(batch); }
 };
 
 typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
 typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
 typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
 typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
-typedef std::unique_ptr<llama_batch, llama_batch_deleter> llama_batch_ptr;
+typedef std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> llama_batch_ext_ptr;
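With the renamed deleter, the RAII typedef frees the new opaque batch automatically. A minimal usage sketch; the capacity values are placeholders:

    #include "llama-cpp.h"

    void example() {
        llama_batch_ext_ptr batch(llama_batch_ext_init(/*n_tokens=*/512, /*n_seq_max=*/1));
        // pass batch.get() to the llama_batch_ext_* functions ...
    }   // llama_batch_ext_free() runs here via llama_batch_ext_deleter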

include/llama.h

Lines changed: 82 additions & 29 deletions
@@ -231,9 +231,38 @@ extern "C" {
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
-    struct llama_batch;
-
-    struct llama_batch_token_info {
+    // Input data for llama_decode
+    //
+    // WARN: This struct is DEPRECATED and will be removed in the future, use llama_batch_ext instead
+    //
+    // A llama_batch object can contain input about one or many sequences
+    // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
+    //
+    // - token  : the token ids of the input (used when embd is NULL)
+    // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
+    // - pos    : the positions of the respective token in the sequence
+    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
+    // - seq_id : the sequence to which the respective token belongs
+    //            (if set to NULL, the sequence ID will be assumed to be 0)
+    // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+    //            (if set to NULL, only the logits for last token will be returned)
+    //
+    typedef struct llama_batch {
+        int32_t n_tokens;
+
+        llama_token  *  token;
+        float        *  embd;
+        llama_pos    *  pos;
+        int32_t      *  n_seq_id;
+        llama_seq_id ** seq_id;
+        int8_t       *  logits; // TODO: rename this to "output"
+    } llama_batch;
+
+    // Input data for llama_decode / llama_encode
+    // It can contain text tokens and embeddings for one or many sequences
+    struct llama_batch_ext;
+
+    struct llama_batch_ext_token_info {
         llama_token token;
         llama_pos pos;
         int32_t n_seq_id;
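Since llama_batch_ext is opaque, callers read per-token state back through the token-info accessor declared further down in this diff instead of indexing struct members. A hedged sketch; the fields of llama_batch_ext_token_info beyond those shown in this hunk are assumed to mirror the legacy layout:

    // iterate an ext batch without touching its internals
    for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch); ++i) {
        struct llama_batch_ext_token_info info = llama_batch_ext_get_token_info(batch, i);
        // use info.token, info.pos, info.n_seq_id, ...
    }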
@@ -815,9 +844,9 @@ extern "C" {
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
-    LLAMA_API struct llama_batch * llama_batch_get_one(
+    DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
             llama_token * tokens,
-            int32_t n_tokens);
+            int32_t n_tokens), "use llama_batch_ext API instead");
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
@@ -826,39 +855,56 @@ extern "C" {
     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
     // The rest of the llama_batch members are allocated with size n_tokens
     // All members are left uninitialized
-    // LLAMA_API struct llama_batch llama_batch_init(
-    //         int32_t n_tokens,
-    //         int32_t embd,
-    //         int32_t n_seq_max);
+    DEPRECATED(LLAMA_API struct llama_batch llama_batch_init(
+            int32_t n_tokens,
+            int32_t embd,
+            int32_t n_seq_max), "use llama_batch_ext API instead");
+
+    // Frees a batch of tokens allocated with llama_batch_init()
+    DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch),
+        "use llama_batch_ext API instead");
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
-    // The batch has to be freed with llama_batch_free()
-    LLAMA_API struct llama_batch * llama_batch_init(
+    // The batch has to be freed with llama_batch_ext_free()
+    LLAMA_API struct llama_batch_ext * llama_batch_ext_init(
             int32_t n_tokens,
             int32_t n_seq_max);
 
+    // Same with llama_batch_init, but initializes the batch with the provided text tokens
+    // First token will be at position pos0
+    // The sequence ID will be fixed to seq_id
+    // The batch has to be freed with llama_batch_ext_free()
+    LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text(
+            llama_token * tokens,
+            int32_t n_tokens,
+            int32_t pos0,
+            int32_t seq_id);
+
     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
-    LLAMA_API struct llama_batch * llama_batch_init_from_embd(
+    // First token will be at position pos0
+    // The sequence ID will be fixed to seq_id
+    // The batch has to be freed with llama_batch_ext_free()
+    LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
            float * embd,
            size_t n_embd,
            int32_t pos0,
            int32_t seq_id);
 
     // Get the number of tokens in the batch
-    LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch);
+    LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
 
-    LLAMA_API struct llama_batch_token_info llama_batch_get_token_info(
-            struct llama_batch * batch,
+    LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
+            struct llama_batch_ext * batch,
             int32_t i);
 
     // Add text tokens to the batch
     // Return values:
     //  0 : success
     // -1 : not enough space in the batch
     // -2 : embd is already set, cannot add text tokens
-    LLAMA_API int32_t llama_batch_add_text_token(
-            struct llama_batch * batch,
+    LLAMA_API int32_t llama_batch_ext_add_text_token(
+            struct llama_batch_ext * batch,
             llama_token token,
             llama_pos pos,
             const llama_seq_id * seq_ids,
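The declaration is cut off after the seq_ids parameter; judging by the call removed from common/common.cpp above, the trailing parameters are the number of sequence ids and an output flag. A sketch of handling the documented return codes under that assumption:

    llama_seq_id seq_id = 0;
    const int32_t res = llama_batch_ext_add_text_token(batch, tok, pos, &seq_id, 1, true);
    if (res == -1) {
        // not enough space in the batch
    } else if (res == -2) {
        // embd is already set, cannot add text tokens
    }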
@@ -868,43 +914,50 @@ extern "C" {
     // Set logits for the token in the ith sequence
     // If pos == -1, logits will be set for the all tokens
     // Returns -1 if the token is not in the batch
-    LLAMA_API int32_t llama_batch_set_logits(
-            struct llama_batch * batch,
+    LLAMA_API int32_t llama_batch_ext_set_logits(
+            struct llama_batch_ext * batch,
             llama_pos pos,
             llama_seq_id seq_id);
 
     // Set logits for the last added token
     // Returns -1 if there is no tokens in the batch
-    LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch);
+    LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch);
 
     // Get a "view" from a number of tokens offset
     // Return returned batch must be freed with llama_batch_free()
-    LLAMA_API struct llama_batch * llama_batch_get_view(
-            struct llama_batch * batch,
+    LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view(
+            struct llama_batch_ext * batch,
             int32_t offset,
             int32_t n_tokens);
 
     // Remove everything from the batch
-    LLAMA_API void llama_batch_clear(struct llama_batch * batch);
+    LLAMA_API void llama_batch_ext_clear(struct llama_batch_ext * batch);
 
-    // Frees a batch of tokens allocated with llama_batch_init()
-    LLAMA_API void llama_batch_free(struct llama_batch * batch);
+    // Frees a batch of tokens allocated with llama_batch_ext_init()
+    // If this is a view, the original batch is not freed
+    LLAMA_API void llama_batch_ext_free(struct llama_batch_ext * batch);
 
     // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
     // Stores the encoder output internally for later use by the decoder cross-attention layers.
     //   0 - success
     // < 0 - error. the KV cache state is restored to the state before this call
-    LLAMA_API int32_t llama_encode(
+    DEPRECATED(LLAMA_API int32_t llama_encode(
+            struct llama_context * ctx,
+            struct llama_batch batch), "use llama_batch_ext API instead");
+    LLAMA_API int32_t llama_text_encode(
             struct llama_context * ctx,
-            struct llama_batch * batch);
+            struct llama_batch_ext * batch);
 
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
     // < 0 - error. the KV cache state is restored to the state before this call
-    LLAMA_API int32_t llama_decode(
+    DEPRECATED(LLAMA_API int32_t llama_decode(
+            struct llama_context * ctx,
+            struct llama_batch batch), "use llama_batch_ext API instead");
+    LLAMA_API int32_t llama_text_decode(
             struct llama_context * ctx,
-            struct llama_batch * batch);
+            struct llama_batch_ext * batch);
 
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
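Putting the new entry points together, a minimal decode sketch; ctx and tokens are placeholders, and since it is not stated whether llama_batch_ext_init_from_text marks any outputs, the last token is flagged explicitly:

    struct llama_batch_ext * batch = llama_batch_ext_init_from_text(
            tokens.data(), (int32_t) tokens.size(), /*pos0=*/0, /*seq_id=*/0);
    llama_batch_ext_set_logits_last(batch);  // request logits for the final token
    if (llama_text_decode(ctx, batch) != 0) {
        // 1 : no KV slot found; < 0 : error (KV cache state restored)
    }
    llama_batch_ext_free(batch);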
