@@ -47,15 +47,30 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

-        self.use_kv_cache = kwargs.get("use_kv_cache", False)
-        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-        self.generate_full_logits = kwargs.get("generate_full_logits", False)
-        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-        self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-        self.max_seq_len = kwargs.get("max_seq_len", 128)
-        self.max_context_len = kwargs.get("max_context_len", 128)
         self.llm_config = kwargs.get("llm_config", None)
+
+        # Set all parameters from llm_config if available, otherwise use kwargs as fallback.
+        if self.llm_config:
+            self.use_kv_cache = self.llm_config.model.use_kv_cache
+            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+            self.generate_full_logits = self.llm_config.debug.generate_full_logits
+            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+            self.input_prune_map_path = self.llm_config.model.input_prune_map
+            self.output_prune_map_path = self.llm_config.model.output_prune_map
+            self.max_seq_len = self.llm_config.export.max_seq_length
+            self.max_context_len = self.llm_config.export.max_context_length
+            self.verbose = self.llm_config.debug.verbose
+        else:
+            # Fall back to kwargs for backward compatibility.
+            self.use_kv_cache = kwargs.get("use_kv_cache", False)
+            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
+            self.generate_full_logits = kwargs.get("generate_full_logits", False)
+            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
+            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
+            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
+            self.max_seq_len = kwargs.get("max_seq_len", 128)
+            self.max_context_len = kwargs.get("max_context_len", 128)
+            self.verbose = kwargs.get("verbose", False)

         assert (
             self.max_context_len >= self.max_seq_len
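
The hunk above swaps a flat list of kwargs reads for a config-first lookup: a structured llm_config, when supplied, takes precedence, and plain kwargs remain as a backward-compatible fallback. Below is a minimal runnable sketch of that precedence pattern; LlmConfig and ModelConfig here are hypothetical stand-ins for illustration, not ExecuTorch's real config classes.

# Minimal sketch of the config-first precedence pattern introduced above.
# LlmConfig/ModelConfig are illustrative stand-ins only.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelConfig:
    use_kv_cache: bool = False


@dataclass
class LlmConfig:
    model: ModelConfig = field(default_factory=ModelConfig)


class Builder:
    def __init__(self, **kwargs):
        self.llm_config: Optional[LlmConfig] = kwargs.get("llm_config", None)
        if self.llm_config:
            # A structured config, when supplied, is the single source of truth.
            self.use_kv_cache = self.llm_config.model.use_kv_cache
        else:
            # Plain kwargs keep pre-existing call sites working unchanged.
            self.use_kv_cache = kwargs.get("use_kv_cache", False)


# Both call styles resolve to the same attribute:
assert Builder(llm_config=LlmConfig(ModelConfig(use_kv_cache=True))).use_kv_cache
assert Builder(use_kv_cache=True).use_kv_cache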
@@ -165,7 +180,7 @@ def __init__(self, **kwargs):
             if model_name not in ["llama3", "llama3_1"]:
                 model_args.rope_scale_factor = 32

-        if kwargs.get("verbose", False):
+        if self.verbose:
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
             for key, weights in checkpoint.items():
@@ -280,7 +295,7 @@ def __init__(self, **kwargs):
                     f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
                 )
         if unexpected:
-            if kwargs.get("verbose", False):
+            if self.verbose:
                 print(f"Unexpected keys: {unexpected}")

         # Prune the input layer if input_prune_map is provided
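
For context on the hunk above: the missing/unexpected lists are presumably the return value of nn.Module.load_state_dict(..., strict=False), which reports incompatible keys instead of raising (the call itself is not shown in this diff). A self-contained sketch of that reporting pattern follows; the Linear model and the "extra.scale" key are illustrative only.

# Sketch of the checkpoint key-checking pattern, built on
# load_state_dict(strict=False). Missing keys are fatal; unexpected
# keys are merely reported when verbose.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
checkpoint = {
    "weight": model.weight.detach().clone(),
    "bias": model.bias.detach().clone(),
    "extra.scale": torch.ones(1),  # an unexpected key: reported, not fatal
}

missing, unexpected = model.load_state_dict(checkpoint, strict=False)
if missing:
    raise ValueError(
        f"The provided checkpoint is missing the following weights that are "
        f"expected by the model: {missing}. Please fix the fqn's in your "
        f"checkpoint to match."
    )
if unexpected:
    print(f"Unexpected keys: {unexpected}")  # prints ['extra.scale']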