@@ -105,6 +105,9 @@ def __init__(
105105 draft_model : Optional [LlamaDraftModel ] = None ,
106106 # Tokenizer Override
107107 tokenizer : Optional [BaseLlamaTokenizer ] = None ,
108+ # KV cache quantization
109+ type_k : Optional [int ] = None ,
110+ type_v : Optional [int ] = None ,
108111 # Misc
109112 verbose : bool = True ,
110113 # Extra Params
@@ -172,6 +175,8 @@ def __init__(
172175 draft_model: Optional draft model to use for speculative decoding.
173176 tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
174177 verbose: Print verbose output to stderr.
178+ type_k: KV cache data type for K (default: f16)
179+ type_v: KV cache data type for V (default: f16)
175180
176181 Raises:
177182 ValueError: If the model path does not exist.
@@ -298,7 +303,11 @@ def __init__(
298303 ) # Must be set to True for speculative decoding
299304 self .context_params .embeddings = embedding # TODO: Rename to embeddings
300305 self .context_params .offload_kqv = offload_kqv
301-
306+ # KV cache quantization
307+ if type_k is not None :
308+ self .context_params .type_k = type_k
309+ if type_v is not None :
310+ self .context_params .type_v = type_v
302311 # Sampling Params
303312 self .last_n_tokens_size = last_n_tokens_size
304313
@@ -1724,6 +1733,7 @@ def __getstate__(self):
17241733 n_threads = self .context_params .n_threads ,
17251734 n_threads_batch = self .context_params .n_threads_batch ,
17261735 rope_scaling_type = self .context_params .rope_scaling_type ,
1736+ pooling_type = self .context_params .pooling_type ,
17271737 rope_freq_base = self .context_params .rope_freq_base ,
17281738 rope_freq_scale = self .context_params .rope_freq_scale ,
17291739 yarn_ext_factor = self .context_params .yarn_ext_factor ,
@@ -1733,6 +1743,7 @@ def __getstate__(self):
17331743 yarn_orig_ctx = self .context_params .yarn_orig_ctx ,
17341744 logits_all = self .context_params .logits_all ,
17351745 embedding = self .context_params .embeddings ,
1746+ offload_kqv = self .context_params .offload_kqv ,
17361747 # Sampling Params
17371748 last_n_tokens_size = self .last_n_tokens_size ,
17381749 # LoRA Params
@@ -1744,51 +1755,17 @@ def __getstate__(self):
17441755 # Chat Format Params
17451756 chat_format = self .chat_format ,
17461757 chat_handler = self .chat_handler ,
1758+ # Speculative Decoding
1759+ draft_model = self .draft_model ,
1760+ # KV cache quantization
1761+ type_k = self .context_params .type_k ,
1762+ type_v = self .context_params .type_v ,
17471763 # Misc
17481764 verbose = self .verbose ,
17491765 )
17501766
17511767 def __setstate__ (self , state ):
1752- self .__init__ (
1753- model_path = state ["model_path" ],
1754- # Model Params
1755- n_gpu_layers = state ["n_gpu_layers" ],
1756- split_mode = state ["split_mode" ],
1757- main_gpu = state ["main_gpu" ],
1758- tensor_split = state ["tensor_split" ],
1759- vocab_only = state ["vocab_only" ],
1760- use_mmap = state ["use_mmap" ],
1761- use_mlock = state ["use_mlock" ],
1762- kv_overrides = state ["kv_overrides" ],
1763- # Context Params
1764- seed = state ["seed" ],
1765- n_ctx = state ["n_ctx" ],
1766- n_batch = state ["n_batch" ],
1767- n_threads = state ["n_threads" ],
1768- n_threads_batch = state ["n_threads_batch" ],
1769- rope_freq_base = state ["rope_freq_base" ],
1770- rope_freq_scale = state ["rope_freq_scale" ],
1771- rope_scaling_type = state ["rope_scaling_type" ],
1772- yarn_ext_factor = state ["yarn_ext_factor" ],
1773- yarn_attn_factor = state ["yarn_attn_factor" ],
1774- yarn_beta_fast = state ["yarn_beta_fast" ],
1775- yarn_beta_slow = state ["yarn_beta_slow" ],
1776- yarn_orig_ctx = state ["yarn_orig_ctx" ],
1777- logits_all = state ["logits_all" ],
1778- embedding = state ["embedding" ],
1779- # Sampling Params
1780- last_n_tokens_size = state ["last_n_tokens_size" ],
1781- # LoRA Params
1782- lora_base = state ["lora_base" ],
1783- lora_path = state ["lora_path" ],
1784- # Backend Params
1785- numa = state ["numa" ],
1786- # Chat Format Params
1787- chat_format = state ["chat_format" ],
1788- chat_handler = state ["chat_handler" ],
1789- # Misc
1790- verbose = state ["verbose" ],
1791- )
1768+ self .__init__ (** state )
17921769
17931770 def save_state (self ) -> LlamaState :
17941771 assert self ._ctx .ctx is not None
0 commit comments