@@ -47,30 +47,18 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

-        self.llm_config = kwargs.get("llm_config", None)
+        self.llm_config = kwargs.get("llm_config")
+        assert self.llm_config is not None, "llm_config must be provided"

-        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
-        if self.llm_config:
-            self.use_kv_cache = self.llm_config.model.use_kv_cache
-            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
-            self.generate_full_logits = self.llm_config.debug.generate_full_logits
-            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
-            self.input_prune_map_path = self.llm_config.model.input_prune_map
-            self.output_prune_map_path = self.llm_config.model.output_prune_map
-            self.max_seq_len = self.llm_config.export.max_seq_length
-            self.max_context_len = self.llm_config.export.max_context_length
-            self.verbose = self.llm_config.debug.verbose
-        else:
-            # Fallback to kwargs for backward compatibility
-            self.use_kv_cache = kwargs.get("use_kv_cache", False)
-            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-            self.generate_full_logits = kwargs.get("generate_full_logits", False)
-            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-            self.max_seq_len = kwargs.get("max_seq_len", 128)
-            self.max_context_len = kwargs.get("max_context_len", 128)
-            self.verbose = kwargs.get("verbose", False)
+        self.use_kv_cache = self.llm_config.model.use_kv_cache
+        self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+        self.generate_full_logits = self.llm_config.debug.generate_full_logits
+        self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+        self.input_prune_map_path = self.llm_config.model.input_prune_map
+        self.output_prune_map_path = self.llm_config.model.output_prune_map
+        self.max_seq_len = self.llm_config.export.max_seq_length
+        self.max_context_len = self.llm_config.export.max_context_length
+        self.verbose = self.llm_config.debug.verbose

         assert (
             self.max_context_len >= self.max_seq_len
@@ -173,7 +161,7 @@ def __init__(self, **kwargs):

         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            model_name = str(self.llm_config.base.model_class)
             assert model_name not in ["llama2", "stories110m"]

             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
@@ -212,7 +200,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif self.llm_config and self.llm_config.quantization.use_spin_quant:
+        elif self.llm_config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)

@@ -221,7 +209,7 @@ def __init__(self, **kwargs):
             )

             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif self.llm_config and self.llm_config.quantization.use_qat:
+        elif self.llm_config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
             if self.llm_config.base.use_lora:
@@ -243,7 +231,7 @@ def __init__(self, **kwargs):

             sanitize_checkpoint_from_pre_quantization(checkpoint)

-        if self.llm_config and self.llm_config.model.use_attention_sink:
+        if self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink

             attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
@@ -343,7 +331,7 @@ def get_example_inputs_kvcache_sdpa(self):
             )

     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode, "preq_mode must be specified"
         assert self.llm_config.base.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",
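With this change, `llm_config` becomes a required constructor argument and the old per-field kwargs fallback (`use_kv_cache`, `max_seq_len`, `verbose`, etc.) no longer applies. A minimal usage sketch of the new contract follows; the `LlmConfig` import path, the `Llama2Model` class name, and the `checkpoint`/`params` kwargs are assumptions made for illustration and are not shown in this diff.

```python
# Sketch only: LlmConfig import path, Llama2Model, and the checkpoint/params
# kwargs are assumed names, not confirmed by this diff.
from executorch.examples.models.llama.model import Llama2Model
from executorch.extension.llm.export.config.llm_config import LlmConfig

# Build the config object that now carries all model/export/debug settings.
llm_config = LlmConfig()
llm_config.model.use_kv_cache = True
llm_config.model.use_sdpa_with_kv_cache = True
llm_config.export.max_seq_length = 128
llm_config.export.max_context_length = 2048
llm_config.debug.verbose = True

# llm_config is now mandatory; omitting it trips the new assert, and passing
# use_kv_cache=True etc. directly as kwargs is silently ignored.
model = Llama2Model(
    llm_config=llm_config,
    checkpoint="/path/to/consolidated.00.pth",
    params="/path/to/params.json",
)
```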