Skip to content

Commit 6aa035b

Browse files
committed
Sync llama : rework embeddings logic
1 parent 903f008 commit 6aa035b

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

llama_cpp/llama_cpp.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -552,27 +552,31 @@ class llama_token_data_array(ctypes.Structure):
552552
)
553553

554554

555-
# // Input data for llama_decode
555+
# // Input data for llama_encode/llama_decode
556556
# // A llama_batch object can contain input about one or many sequences
557557
# // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
558558
# //
559559
# // - token : the token ids of the input (used when embd is NULL)
560560
# // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
561561
# // - pos : the positions of the respective token in the sequence
562-
# // (if set to NULL, the token position will be tracked automatically by llama_decode)
562+
# // (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode)
563563
# // - seq_id : the sequence to which the respective token belongs
564564
# // (if set to NULL, the sequence ID will be assumed to be 0)
565565
# // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
566-
# // (if set to NULL, only the logits for last token will be returned)
566+
# // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
567+
# // (if set to NULL:
568+
# // - if embeddings: all tokens are output
569+
# // - if not: only the last token is output
570+
# // )
567571
# //
568572
# typedef struct llama_batch {
569573
# int32_t n_tokens;
570574

571575
# llama_token * token;
572576
# float * embd;
573577
# llama_pos * pos;
574-
# int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
575-
# llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
578+
# int32_t * n_seq_id;
579+
# llama_seq_id ** seq_id;
576580
# int8_t * logits; // TODO: rename this to "output"
577581
# } llama_batch;
578582
class llama_batch(ctypes.Structure):
@@ -2532,12 +2536,12 @@ def llama_n_threads_batch(ctx: llama_context_p, /) -> int:
25322536

25332537

25342538
# // Set whether the model is in embeddings mode or not
2535-
# // If true, embeddings will be returned but logits will not
25362539
# LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
25372540
@ctypes_function("llama_set_embeddings", [llama_context_p_ctypes, ctypes.c_bool], None)
25382541
def llama_set_embeddings(ctx: llama_context_p, embeddings: bool, /):
2539-
"""Set whether the model is in embeddings model or not
2540-
If true, embeddings will be returned but logits will not"""
2542+
"""
2543+
Set whether the model is in embeddings mode or not
2544+
"""
25412545
...
25422546

25432547

0 commit comments

Comments (0)