@@ -92,6 +92,7 @@ def __init__(
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
+        flash_attn: bool = False,
         # Sampling Params
         last_n_tokens_size: int = 64,
         # LoRA Params
@@ -168,6 +169,7 @@ def __init__(
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
+            flash_attn: Use flash attention.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
             lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
@@ -310,6 +312,7 @@ def __init__(
         )  # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
+        self.context_params.flash_attn = flash_attn
         # KV cache quantization
         if type_k is not None:
             self.context_params.type_k = type_k
@@ -1774,6 +1777,7 @@ def __getstate__(self):
             logits_all=self.context_params.logits_all,
             embedding=self.context_params.embeddings,
             offload_kqv=self.context_params.offload_kqv,
+            flash_attn=self.context_params.flash_attn,
             # Sampling Params
             last_n_tokens_size=self.last_n_tokens_size,
             # LoRA Params
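Since `__getstate__` now carries the flag (last hunk), a pickle round-trip should reconstruct the instance with the same setting. A quick sketch, assuming the `llm` object from above and that `__getstate__` returns the kwargs dict shown in the hunk:

```python
# The serialized kwargs dict now includes flash_attn, so pickling
# preserves the setting (sketch; llm is the instance from above).
state = llm.__getstate__()
assert state["flash_attn"] is True
```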