@@ -47,30 +47,18 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

-        self.llm_config = kwargs.get("llm_config", None)
+        self.llm_config = kwargs.get("llm_config")
+        assert self.llm_config is not None, "llm_config must be provided"

-        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
-        if self.llm_config:
-            self.use_kv_cache = self.llm_config.model.use_kv_cache
-            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
-            self.generate_full_logits = self.llm_config.debug.generate_full_logits
-            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
-            self.input_prune_map_path = self.llm_config.model.input_prune_map
-            self.output_prune_map_path = self.llm_config.model.output_prune_map
-            self.max_seq_len = self.llm_config.export.max_seq_length
-            self.max_context_len = self.llm_config.export.max_context_length
-            self.verbose = self.llm_config.debug.verbose
-        else:
-            # Fallback to kwargs for backward compatibility
-            self.use_kv_cache = kwargs.get("use_kv_cache", False)
-            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-            self.generate_full_logits = kwargs.get("generate_full_logits", False)
-            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-            self.max_seq_len = kwargs.get("max_seq_len", 128)
-            self.max_context_len = kwargs.get("max_context_len", 128)
-            self.verbose = kwargs.get("verbose", False)
+        self.use_kv_cache = self.llm_config.model.use_kv_cache
+        self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+        self.generate_full_logits = self.llm_config.debug.generate_full_logits
+        self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+        self.input_prune_map_path = self.llm_config.model.input_prune_map
+        self.output_prune_map_path = self.llm_config.model.output_prune_map
+        self.max_seq_len = self.llm_config.export.max_seq_length
+        self.max_context_len = self.llm_config.export.max_context_length
+        self.verbose = self.llm_config.debug.verbose

         assert (
             self.max_context_len >= self.max_seq_len
@@ -173,7 +161,7 @@ def __init__(self, **kwargs):

         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            model_name = str(self.llm_config.base.model_class)
             assert model_name not in ["llama2", "stories110m"]

             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
@@ -212,7 +200,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif self.llm_config and self.llm_config.quantization.use_spin_quant:
+        elif self.llm_config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)

@@ -221,7 +209,7 @@ def __init__(self, **kwargs):
             )

             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif self.llm_config and self.llm_config.quantization.use_qat:
+        elif self.llm_config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
             if self.llm_config.base.use_lora:
@@ -243,7 +231,7 @@ def __init__(self, **kwargs):

             sanitize_checkpoint_from_pre_quantization(checkpoint)

-        if self.llm_config and self.llm_config.model.use_attention_sink:
+        if self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink

             attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
@@ -343,7 +331,7 @@ def get_example_inputs_kvcache_sdpa(self):
         )

     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode, "preq_mode must be specified"
         assert self.llm_config.base.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",