diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 71d94ebd8..c4a54b395 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -90,6 +90,8 @@ def __init__(
         yarn_orig_ctx: int = 0,
         logits_all: bool = False,
         embedding: bool = False,
+        n_seq_max: Optional[int] = None,
+        kv_unified: Optional[bool] = None,
         offload_kqv: bool = True,
         flash_attn: bool = False,
         op_offload: Optional[bool] = None,
@@ -172,6 +174,8 @@ def __init__(
             yarn_orig_ctx: YaRN original context size
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
+            n_seq_max: Maximum number of sequences in KV cache
+            kv_unified: Use unified KV cache across sequences
             offload_kqv: Offload K, Q, V to GPU.
             flash_attn: Use flash attention.
             op_offload: offload host tensor operations to device
@@ -343,6 +347,14 @@ def __init__(
         self.context_params.offload_kqv = offload_kqv
         self.context_params.flash_attn = flash_attn
 
+        # this allows for batch embedding many sequences
+        if n_seq_max is not None:
+            self.context_params.n_seq_max = n_seq_max
+        if kv_unified is not None:
+            self.context_params.kv_unified = kv_unified
+        elif embedding and n_seq_max is None:
+            self.context_params.kv_unified = True
+
         if op_offload is not None:
             self.context_params.op_offload = op_offload
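
Not part of the patch, just a usage sketch of the new parameters: assuming an embedding-capable GGUF at a placeholder path, the two new constructor arguments let callers size the KV cache for batched embedding, while the existing `embed` call stays unchanged.

```python
from llama_cpp import Llama

# Placeholder model path; any embedding-capable GGUF works here.
llm = Llama(
    model_path="./models/embedding-model.gguf",
    embedding=True,    # embedding mode only
    n_seq_max=8,       # new: allow up to 8 sequences in the KV cache
    kv_unified=True,   # new: share one unified KV cache across sequences
)

# Embed several inputs in one batch; each input occupies its own sequence slot.
vectors = llm.embed(["first document", "second document", "third document"])
```

Per the patch, passing `embedding=True` without `n_seq_max` also enables `kv_unified` by default, so the explicit `kv_unified=True` above is only illustrative.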