@@ -231,9 +231,38 @@ extern "C" {
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
-    struct llama_batch;
-
-    struct llama_batch_token_info {
+    // Input data for llama_decode
+    //
+    // WARN: This struct is DEPRECATED and will be removed in the future, use llama_batch_ext instead
+    //
+    // A llama_batch object can contain input about one or many sequences
+    // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
+    //
+    // - token  : the token ids of the input (used when embd is NULL)
+    // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
+    // - pos    : the positions of the respective token in the sequence
+    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
+    // - seq_id : the sequence to which the respective token belongs
+    //            (if set to NULL, the sequence ID will be assumed to be 0)
+    // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+    //            (if set to NULL, only the logits for the last token will be returned)
+    //
+    typedef struct llama_batch {
+        int32_t n_tokens;
+
+        llama_token  *  token;
+        float        *  embd;
+        llama_pos    *  pos;
+        int32_t      *  n_seq_id;
+        llama_seq_id ** seq_id;
+        int8_t       *  logits; // TODO: rename this to "output"
+    } llama_batch;
+
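To make the field semantics above concrete, here is a minimal sketch of filling the deprecated struct by hand for a three-token prompt on sequence 0, requesting output only for the last token. The token ids are made up for illustration.

```c
#include "llama.h"

static void fill_batch_example(void) {
    llama_token  tokens[3]    = { 1, 15043, 2787 }; // hypothetical token ids
    llama_pos    pos[3]       = { 0, 1, 2 };
    int32_t      n_seq_id[3]  = { 1, 1, 1 };
    llama_seq_id seq0         = 0;
    llama_seq_id * seq_id[3]  = { &seq0, &seq0, &seq0 };
    int8_t       logits[3]    = { 0, 0, 1 };        // output only for the last token

    struct llama_batch batch = {
        /*.n_tokens =*/ 3,
        /*.token    =*/ tokens,
        /*.embd     =*/ NULL,   // using token ids, not embeddings
        /*.pos      =*/ pos,
        /*.n_seq_id =*/ n_seq_id,
        /*.seq_id   =*/ seq_id,
        /*.logits   =*/ logits,
    };
    (void) batch; // would be passed to the deprecated llama_decode(ctx, batch)
}
```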
+    // Input data for llama_decode / llama_encode
+    // It can contain text tokens and embeddings for one or many sequences
+    struct llama_batch_ext;
+
+    struct llama_batch_ext_token_info {
         llama_token token;
         llama_pos   pos;
         int32_t     n_seq_id;
@@ -815,9 +844,9 @@ extern "C" {
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
-    LLAMA_API struct llama_batch * llama_batch_get_one(
+    DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
                llama_token * tokens,
-               int32_t       n_tokens);
+               int32_t       n_tokens), "use llama_batch_ext API instead");
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
@@ -826,39 +855,56 @@ extern "C" {
     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
     // The rest of the llama_batch members are allocated with size n_tokens
     // All members are left uninitialized
-    // LLAMA_API struct llama_batch llama_batch_init(
-    //         int32_t n_tokens,
-    //         int32_t embd,
-    //         int32_t n_seq_max);
+    DEPRECATED(LLAMA_API struct llama_batch llama_batch_init(
+            int32_t n_tokens,
+            int32_t embd,
+            int32_t n_seq_max), "use llama_batch_ext API instead");
+
+    // Frees a batch of tokens allocated with llama_batch_init()
+    DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch),
+            "use llama_batch_ext API instead");
 
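A quick sketch of the allocation rule described above, for the deprecated pair; the embedding size 4096 is a stand-in for the model's n_embd.

```c
#include "llama.h"

static void batch_alloc_example(void) {
    // embd == 0: the token array is allocated (512 llama_token entries)
    struct llama_batch tok_batch = llama_batch_init(512, 0, 1);
    // embd != 0: the embd array is allocated (32 * 4096 floats), token stays NULL
    struct llama_batch emb_batch = llama_batch_init(32, 4096, 1);
    // ... fill the (uninitialized) members, then decode ...
    llama_batch_free(tok_batch);
    llama_batch_free(emb_batch);
}
```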
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
-    // The batch has to be freed with llama_batch_free()
-    LLAMA_API struct llama_batch * llama_batch_init(
+    // The batch has to be freed with llama_batch_ext_free()
+    LLAMA_API struct llama_batch_ext * llama_batch_ext_init(
             int32_t n_tokens,
             int32_t n_seq_max);
 
+    // Same as llama_batch_ext_init, but initializes the batch with the provided text tokens
+    // First token will be at position pos0
+    // The sequence ID will be fixed to seq_id
+    // The batch has to be freed with llama_batch_ext_free()
+    LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text(
+            llama_token * tokens,
+            int32_t   n_tokens,
+            int32_t   pos0,
+            int32_t   seq_id);
+
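A hypothetical sketch of migrating from the deprecated llama_batch_get_one helper above: the position origin and sequence id become explicit, and the batch is now heap-allocated, so it must be freed.

```c
#include "llama.h"

static void get_one_migration_example(struct llama_context * ctx,
                                      llama_token * tokens, int32_t n_tokens) {
    // old: struct llama_batch batch = llama_batch_get_one(tokens, n_tokens);
    struct llama_batch_ext * batch =
        llama_batch_ext_init_from_text(tokens, n_tokens, /*pos0 =*/ 0, /*seq_id =*/ 0);
    llama_text_decode(ctx, batch); // error handling omitted
    llama_batch_ext_free(batch);
}
```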
     // Same as llama_batch_ext_init, but initializes the batch with the provided raw embeddings
-    LLAMA_API struct llama_batch * llama_batch_init_from_embd(
+    // First token will be at position pos0
+    // The sequence ID will be fixed to seq_id
+    // The batch has to be freed with llama_batch_ext_free()
+    LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
             float * embd,
             size_t  n_embd,
             int32_t pos0,
             int32_t seq_id);
 
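A sketch of submitting one precomputed embedding vector (e.g. from a multimodal encoder) at position 0 of sequence 0. The size 4096 is illustrative and is assumed to equal the model's n_embd.

```c
#include "llama.h"

static void embd_batch_example(struct llama_context * ctx, float * embd_4096) {
    struct llama_batch_ext * batch =
        llama_batch_ext_init_from_embd(embd_4096, /*n_embd =*/ 4096,
                                       /*pos0 =*/ 0, /*seq_id =*/ 0);
    llama_text_decode(ctx, batch); // error handling omitted
    llama_batch_ext_free(batch);
}
```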
     // Get the number of tokens in the batch
-    LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch);
+    LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
 
-    LLAMA_API struct llama_batch_token_info llama_batch_get_token_info(
-            struct llama_batch * batch,
+    LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
+            struct llama_batch_ext * batch,
             int32_t i);
 
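A sketch of walking the tokens of a batch through the accessor API instead of touching struct internals; only the token_info fields visible in this diff (token, pos, n_seq_id) are used.

```c
#include <stdio.h>
#include "llama.h"

static void dump_batch(struct llama_batch_ext * batch) {
    const int32_t n = llama_batch_ext_get_n_tokens(batch);
    for (int32_t i = 0; i < n; i++) {
        struct llama_batch_ext_token_info info = llama_batch_ext_get_token_info(batch, i);
        // llama_token and llama_pos are int32_t, so %d is fine here
        printf("i=%d token=%d pos=%d n_seq_id=%d\n", i, info.token, info.pos, info.n_seq_id);
    }
}
```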
     // Add text tokens to the batch
     // Return values:
     //   0 : success
     //  -1 : not enough space in the batch
     //  -2 : embd is already set, cannot add text tokens
-    LLAMA_API int32_t llama_batch_add_text_token(
-            struct llama_batch * batch,
+    LLAMA_API int32_t llama_batch_ext_add_text_token(
+            struct llama_batch_ext * batch,
             llama_token token,
             llama_pos pos,
             const llama_seq_id * seq_ids,
@@ -868,43 +914,50 @@ extern "C" {
     // Set logits for the token at position pos in sequence seq_id
     // If pos == -1, logits will be set for all tokens
     // Returns -1 if the token is not in the batch
-    LLAMA_API int32_t llama_batch_set_logits(
-            struct llama_batch * batch,
+    LLAMA_API int32_t llama_batch_ext_set_logits(
+            struct llama_batch_ext * batch,
             llama_pos pos,
             llama_seq_id seq_id);
 
     // Set logits for the last added token
     // Returns -1 if there are no tokens in the batch
-    LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch);
+    LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch);
 
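Taken together, a sketch of appending a prompt one token at a time and then requesting logits only for the final token. The tail of the llama_batch_ext_add_text_token signature is cut off by the hunk above, so the trailing arguments (a seq id count and an output flag) are assumptions here.

```c
#include "llama.h"

static void build_prompt_batch(llama_token * prompt, int32_t n_prompt) {
    llama_seq_id seq_id = 0;
    struct llama_batch_ext * batch = llama_batch_ext_init(n_prompt, /*n_seq_max =*/ 1);
    for (int32_t i = 0; i < n_prompt; i++) {
        // last two arguments assumed: n_seq_ids = 1, output flag = false
        llama_batch_ext_add_text_token(batch, prompt[i], /*pos =*/ i, &seq_id, 1, false);
    }
    llama_batch_ext_set_logits_last(batch); // output for the last token only
    // ... llama_text_decode(ctx, batch); llama_batch_ext_free(batch); ...
}
```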
     // Get a "view" of the batch starting from a given token offset
     // The returned batch must be freed with llama_batch_ext_free()
-    LLAMA_API struct llama_batch * llama_batch_get_view(
-            struct llama_batch * batch,
+    LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view(
+            struct llama_batch_ext * batch,
             int32_t offset,
             int32_t n_tokens);
 
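A sketch of the intended pattern for views: decoding a long batch in chunks. Per the free function below, freeing a view does not free the parent batch.

```c
#include "llama.h"

static void chunked_decode(struct llama_context * ctx,
                           struct llama_batch_ext * batch, int32_t n_chunk) {
    const int32_t n_total = llama_batch_ext_get_n_tokens(batch);
    for (int32_t ofs = 0; ofs < n_total; ofs += n_chunk) {
        const int32_t n = n_total - ofs < n_chunk ? n_total - ofs : n_chunk;
        struct llama_batch_ext * view = llama_batch_ext_get_view(batch, ofs, n);
        llama_text_decode(ctx, view); // error handling omitted
        llama_batch_ext_free(view);   // frees the view only
    }
    llama_batch_ext_free(batch);      // now free the parent batch
}
```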
     // Remove everything from the batch
-    LLAMA_API void llama_batch_clear(struct llama_batch * batch);
+    LLAMA_API void llama_batch_ext_clear(struct llama_batch_ext * batch);
 
-    // Frees a batch of tokens allocated with llama_batch_init()
-    LLAMA_API void llama_batch_free(struct llama_batch * batch);
+    // Frees a batch of tokens allocated with llama_batch_ext_init()
+    // If this is a view, the original batch is not freed
+    LLAMA_API void llama_batch_ext_free(struct llama_batch_ext * batch);
 
     // Processes a batch of tokens with the encoder part of the encoder-decoder model.
     // Stores the encoder output internally for later use by the decoder cross-attention layers.
     //   0 - success
     // < 0 - error. the KV cache state is restored to the state before this call
-    LLAMA_API int32_t llama_encode(
+    DEPRECATED(LLAMA_API int32_t llama_encode(
+            struct llama_context * ctx,
+            struct llama_batch batch), "use llama_batch_ext API instead");
+    LLAMA_API int32_t llama_text_encode(
             struct llama_context * ctx,
-            struct llama_batch * batch);
+            struct llama_batch_ext * batch);
 
     // Positive return values do not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
     // < 0 - error. the KV cache state is restored to the state before this call
-    LLAMA_API int32_t llama_decode(
+    DEPRECATED(LLAMA_API int32_t llama_decode(
+            struct llama_context * ctx,
+            struct llama_batch batch), "use llama_batch_ext API instead");
+    LLAMA_API int32_t llama_text_decode(
             struct llama_context * ctx,
-            struct llama_batch * batch);
+            struct llama_batch_ext * batch);
 
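Finally, a sketch of one prompt-then-generate step with the new entry point. Sampling is elided: pick_token is a hypothetical stand-in for a sampler, and the add_text_token tail arguments are assumed as noted earlier.

```c
#include "llama.h"

llama_token pick_token(const float * logits); // hypothetical sampler, declared elsewhere

static void decode_step_example(struct llama_context * ctx,
                                llama_token * prompt, int32_t n_prompt) {
    // decode the whole prompt in one batch
    struct llama_batch_ext * batch =
        llama_batch_ext_init_from_text(prompt, n_prompt, /*pos0 =*/ 0, /*seq_id =*/ 0);
    if (llama_text_decode(ctx, batch) != 0) { /* warning or error */ }
    llama_batch_ext_free(batch);

    // feed the sampled token back at the next position
    llama_token  tok = pick_token(llama_get_logits(ctx));
    llama_seq_id seq = 0;
    struct llama_batch_ext * next = llama_batch_ext_init(/*n_tokens =*/ 1, /*n_seq_max =*/ 1);
    llama_batch_ext_add_text_token(next, tok, /*pos =*/ n_prompt, &seq, 1, true); // tail args assumed
    llama_text_decode(ctx, next);
    llama_batch_ext_free(next);
}
```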
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)