@@ -230,6 +230,8 @@ def __init__(
         tensor_split: Optional[List[float]] = None,
         rope_freq_base: float = 10000.0,
         rope_freq_scale: float = 1.0,
+        n_gqa: Optional[int] = None,  # (TEMPORARY) must be 8 for llama2 70b
+        rms_norm_eps: Optional[float] = None,  # (TEMPORARY)
         verbose: bool = True,
     ):
         """Load a llama.cpp model from `model_path`.
@@ -291,6 +293,12 @@ def __init__(
         self.params.rope_freq_base = rope_freq_base
         self.params.rope_freq_scale = rope_freq_scale

+        if n_gqa is not None:
+            self.params.n_gqa = n_gqa
+
+        if rms_norm_eps is not None:
+            self.params.rms_norm_eps = rms_norm_eps
+
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
@@ -1530,6 +1538,10 @@ def __getstate__(self):
             lora_base=self.lora_base,
             lora_path=self.lora_path,
             tensor_split=self.tensor_split,
+            ### TEMPORARY ###
+            n_gqa=self.params.n_gqa,
+            rms_norm_eps=self.params.rms_norm_eps,
+            ### TEMPORARY ###
             ### DEPRECATED ###
             n_parts=self.n_parts,
             ### DEPRECATED ###
@@ -1539,7 +1551,6 @@ def __setstate__(self, state):
         self.__init__(
             model_path=state["model_path"],
             n_ctx=state["n_ctx"],
-            n_parts=state["n_parts"],
             n_gpu_layers=state["n_gpu_layers"],
             seed=state["seed"],
             f16_kv=state["f16_kv"],
@@ -1556,6 +1567,13 @@ def __setstate__(self, state):
             lora_path=state["lora_path"],
             tensor_split=state["tensor_split"],
             verbose=state["verbose"],
+            ### TEMPORARY ###
+            n_gqa=state["n_gqa"],
+            rms_norm_eps=state["rms_norm_eps"],
+            ### TEMPORARY ###
+            ### DEPRECATED ###
+            n_parts=state["n_parts"],
+            ### DEPRECATED ###
         )

     def save_state(self) -> LlamaState:
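A minimal usage sketch of the two new keyword arguments. The model path and the rms_norm_eps value below are illustrative assumptions, not values fixed by this diff; the diff itself only states that n_gqa must be 8 for LLaMA 2 70B.

    from llama_cpp import Llama

    # Hypothetical local model file; substitute your own LLaMA 2 70B GGML path.
    llm = Llama(
        model_path="./models/llama-2-70b.ggmlv3.q4_0.bin",
        n_gqa=8,            # per the (TEMPORARY) comment: must be 8 for llama2 70b
        rms_norm_eps=1e-5,  # assumed example value; check your model's config
    )

Because __getstate__ and __setstate__ now carry both fields, a pickled instance should restore them through __init__ on load:

    import pickle

    payload = pickle.dumps(llm)   # state dict now includes n_gqa / rms_norm_eps
    llm2 = pickle.loads(payload)  # __setstate__ forwards them back to __init__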