@@ -429,10 +429,12 @@ class llama_batch(ctypes.Structure):
     The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens

     Attributes:
+        n_tokens (int): number of tokens
         token (ctypes.Array[llama_token]): the token ids of the input (used when embd is NULL)
         embd (ctypes.Array[ctypes.c_float]): token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
         pos (ctypes.Array[llama_pos]): the positions of the respective token in the sequence
         seq_id (ctypes.Array[ctypes.Array[llama_seq_id]]): the sequence to which the respective token belongs
+        logits (ctypes.Array[ctypes.c_int8]): if zero, the logits for the respective token will not be output
     """

     _fields_ = [
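The two added docstring lines describe members the struct already exposes; below is a rough sketch (not from this commit, and assuming the existing low-level `llama_batch_init`/`llama_batch_free` helpers plus placeholder token ids) of how the `logits` flag array can limit output to the final token:

```python
import llama_cpp

tokens = [1, 15043, 2787]                 # placeholder token ids, for illustration only
batch = llama_cpp.llama_batch_init(len(tokens), 0, 1)  # embd=0 -> token ids are used
batch.n_tokens = len(tokens)
for i, tok in enumerate(tokens):
    batch.token[i] = tok                  # token id of the i-th input
    batch.pos[i] = i                      # position of the token in the sequence
    batch.n_seq_id[i] = 1                 # assumed field: one sequence id per token
    batch.seq_id[i][0] = 0                # all tokens belong to sequence 0
    batch.logits[i] = 0                   # 0 = do not output logits for this token
batch.logits[batch.n_tokens - 1] = 1      # request logits only for the last token
# ... llama_decode(ctx, batch) would run here ...
llama_cpp.llama_batch_free(batch)
```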
@@ -547,6 +549,7 @@ class llama_model_params(ctypes.Structure):
 #     uint32_t seed;             // RNG seed, -1 for random
 #     uint32_t n_ctx;            // text context, 0 = from model
 #     uint32_t n_batch;          // prompt processing maximum batch size
+#     uint32_t n_parallel;       // number of parallel sequences (i.e. distinct states for recurrent models)
 #     uint32_t n_threads;        // number of threads to use for generation
 #     uint32_t n_threads_batch;  // number of threads to use for batch processing

@@ -588,6 +591,7 @@ class llama_context_params(ctypes.Structure):
         seed (int): RNG seed, -1 for random
         n_ctx (int): text context, 0 = from model
         n_batch (int): prompt processing maximum batch size
+        n_parallel (int): number of parallel sequences (i.e. distinct states for recurrent models)
         n_threads (int): number of threads to use for generation
         n_threads_batch (int): number of threads to use for batch processing
         rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
@@ -615,6 +619,7 @@ class llama_context_params(ctypes.Structure):
         ("seed", ctypes.c_uint32),
         ("n_ctx", ctypes.c_uint32),
         ("n_batch", ctypes.c_uint32),
+        ("n_parallel", ctypes.c_uint32),
         ("n_threads", ctypes.c_uint32),
         ("n_threads_batch", ctypes.c_uint32),
         ("rope_scaling_type", ctypes.c_int),
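The three hunks above thread the same new `n_parallel` member through the C header comment, the class docstring, and the ctypes `_fields_` layout. A minimal sketch of setting it from the low-level API (assuming `llama_context_default_params()` from the existing bindings; the values are arbitrary):

```python
import llama_cpp

params = llama_cpp.llama_context_default_params()
params.n_ctx = 4096     # text context, 0 = from model
params.n_batch = 512    # prompt processing maximum batch size
params.n_parallel = 2   # two distinct sequence states, e.g. for a recurrent model
# params would then be passed to llama_new_context_with_model(model, params)
```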
@@ -1322,7 +1327,7 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
 # // seq_id < 0 : match any sequence
 # // p0 < 0     : [0, p1]
 # // p1 < 0     : [p0, inf)
-# LLAMA_API void llama_kv_cache_seq_rm(
+# LLAMA_API bool llama_kv_cache_seq_rm(
 #         struct llama_context * ctx,
 #                 llama_seq_id   seq_id,
 #                    llama_pos   p0,
@@ -1335,15 +1340,15 @@ def llama_kv_cache_clear(ctx: llama_context_p, /):
         llama_pos,
         llama_pos,
     ],
-    None,
+    ctypes.c_bool,
 )
 def llama_kv_cache_seq_rm(
     ctx: llama_context_p,
     seq_id: Union[llama_seq_id, int],
     p0: Union[llama_pos, int],
     p1: Union[llama_pos, int],
     /,
-):
+) -> bool:
     """Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     seq_id < 0 : match any sequence
     p0 < 0 : [0, p1]
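With the return type changed from `None` to `ctypes.c_bool`, callers can now detect when a removal was not performed. A hedged usage sketch (assuming `ctx` is an already-created context):

```python
import llama_cpp

# Drop everything from position 32 onward in sequence 0; a False return
# (e.g. a partial range cannot be removed) triggers a full clear instead.
removed = llama_cpp.llama_kv_cache_seq_rm(ctx, 0, 32, -1)
if not removed:
    llama_cpp.llama_kv_cache_clear(ctx)
```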
@@ -1754,7 +1759,10 @@ def llama_get_logits(ctx: llama_context_p, /) -> CtypesArray[ctypes.c_float]:
     The logits for the last token are stored in the last row
     Logits for which llama_batch.logits[i] == 0 are undefined
     Rows: n_tokens provided with llama_batch
-    Cols: n_vocab"""
+    Cols: n_vocab
+
+    Returns:
+        Pointer to the logits buffer of shape (n_tokens, n_vocab)"""
     ...


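The expanded docstring makes the buffer shape explicit; the sketch below (an illustration only, assuming an existing context `ctx`, an `n_tokens` count from the decoded batch, and NumPy, none of which come from this diff) views the returned pointer as an `(n_tokens, n_vocab)` array:

```python
import llama_cpp
import numpy as np

n_vocab = llama_cpp.llama_n_vocab(llama_cpp.llama_get_model(ctx))
logits_ptr = llama_cpp.llama_get_logits(ctx)
# Zero-copy view of the logits buffer; rows whose llama_batch.logits entry
# was 0 are undefined and must not be read.
logits = np.ctypeslib.as_array(logits_ptr, shape=(n_tokens, n_vocab))
last_token_logits = logits[-1]   # logits for the last decoded token
```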