@@ -47,15 +47,30 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)
 
-        self.use_kv_cache = kwargs.get("use_kv_cache", False)
-        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-        self.generate_full_logits = kwargs.get("generate_full_logits", False)
-        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-        self.max_seq_len = kwargs.get("max_seq_len", 128)
-        self.max_context_len = kwargs.get("max_context_len", 128)
         self.llm_config = kwargs.get("llm_config", None)
+
+        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
+        if self.llm_config:
+            self.use_kv_cache = self.llm_config.model.use_kv_cache
+            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+            self.generate_full_logits = self.llm_config.debug.generate_full_logits
+            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+            self.input_prune_map_path = self.llm_config.model.input_prune_map
+            self.output_prune_map_path = self.llm_config.model.output_prune_map
+            self.max_seq_len = self.llm_config.export.max_seq_length
+            self.max_context_len = self.llm_config.export.max_context_length
+            self.verbose = self.llm_config.debug.verbose
+        else:
+            # Fallback to kwargs for backward compatibility
+            self.use_kv_cache = kwargs.get("use_kv_cache", False)
+            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
+            self.generate_full_logits = kwargs.get("generate_full_logits", False)
+            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
+            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
+            self.max_seq_len = kwargs.get("max_seq_len", 128)
+            self.max_context_len = kwargs.get("max_context_len", 128)
+            self.verbose = kwargs.get("verbose", False)
 
         assert (
             self.max_context_len >= self.max_seq_len
@@ -165,7 +180,7 @@ def __init__(self, **kwargs):
         if model_name not in ["llama3", "llama3_1"]:
             model_args.rope_scale_factor = 32
 
-        if kwargs.get("verbose", False):
+        if self.verbose:
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
             for key, weights in checkpoint.items():
@@ -280,7 +295,7 @@ def __init__(self, **kwargs):
                 f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
             )
         if unexpected:
-            if kwargs.get("verbose", False):
+            if self.verbose:
                 print(f"Unexpected keys: {unexpected}")
 
         # Prune the input layer if input_prune_map is provided
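
The change above routes all construction-time options through a single llm_config object (read from its model/export/debug sub-configs, as the field names in the diff show), and only falls back to individual kwargs when no config is given. A minimal, self-contained sketch of that precedence rule follows; DemoConfig and DemoArgs are illustrative names, not the repository's API.

```python
# Sketch of the "llm_config wins over kwargs" precedence introduced in the diff.
# DemoConfig / DemoArgs are hypothetical stand-ins, not ExecuTorch classes.
from dataclasses import dataclass


@dataclass
class DemoConfig:
    use_kv_cache: bool = False
    max_seq_length: int = 128
    max_context_length: int = 128


class DemoArgs:
    def __init__(self, **kwargs):
        self.llm_config = kwargs.get("llm_config", None)
        if self.llm_config:
            # Config object takes precedence when provided.
            self.use_kv_cache = self.llm_config.use_kv_cache
            self.max_seq_len = self.llm_config.max_seq_length
            self.max_context_len = self.llm_config.max_context_length
        else:
            # Backward-compatible kwargs path.
            self.use_kv_cache = kwargs.get("use_kv_cache", False)
            self.max_seq_len = kwargs.get("max_seq_len", 128)
            self.max_context_len = kwargs.get("max_context_len", 128)
        assert self.max_context_len >= self.max_seq_len


# Both construction paths yield the same attributes.
a = DemoArgs(
    llm_config=DemoConfig(use_kv_cache=True, max_seq_length=512, max_context_length=512)
)
b = DemoArgs(use_kv_cache=True, max_seq_len=512, max_context_len=512)
assert a.use_kv_cache == b.use_kv_cache and a.max_seq_len == b.max_seq_len
```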