@@ -55,7 +55,7 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.max_context_len = kwargs.get("max_context_len", 128)
-        self.args = kwargs.get("args", None)
+        self.llm_config = kwargs.get("llm_config", None)

         assert (
             self.max_context_len >= self.max_seq_len
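For orientation, here is a minimal sketch of the call site implied by this hunk. Only the kwarg names and their defaults come from the diff; the Llama2Model class name, the llm_config object, and the concrete values are assumptions for illustration.

# Hypothetical call site; kwargs mirror the hunk above.
model = Llama2Model(
    llm_config=llm_config,        # structured config object replacing the old flat `args`
    output_prune_map_path=None,
    max_seq_len=128,
    max_context_len=2048,         # the assert below requires max_context_len >= max_seq_len
)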
@@ -158,10 +158,11 @@ def __init__(self, **kwargs):

         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            assert self.args.model not in ["llama2", "stories110m"]
+            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            assert model_name not in ["llama2", "stories110m"]

             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
-            if self.args.model not in ["llama3", "llama3_1"]:
+            if model_name not in ["llama3", "llama3_1"]:
                 model_args.rope_scale_factor = 32

         if kwargs.get("verbose", False):
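A sketch of the fallback introduced above, pulled out as a standalone helper. The field path `llm_config.base.model_class` and the model-name lists come from the hunk; the helper name is an assumption.

# Illustration only: resolve the model name the same way the hunk does.
def resolve_model_name(llm_config):
    # str() of the model_class field; default to "llama3" when no llm_config was passed
    return str(llm_config.base.model_class) if llm_config else "llama3"

# "llama3" and "llama3_1" keep the default RoPE scale factor;
# newer model classes (e.g. llama3_2) get rope_scale_factor = 32.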
@@ -196,7 +197,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
+        elif self.llm_config and self.llm_config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)

@@ -205,19 +206,20 @@ def __init__(self, **kwargs):
             )

             sanitize_checkpoint_from_pre_quantization(checkpoint)
-        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+        elif self.llm_config and self.llm_config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
-            if hasattr(self.args, "use_lora") and self.args.use_lora:
-                assert model_args.lora_args["rank"] == self.args.use_lora
+            if self.llm_config.base.use_lora:
+                lora_rank = self.llm_config.base.use_lora
+                assert model_args.lora_args["rank"] == lora_rank
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )

                 self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    self.args.use_lora,
+                    lora_rank,
                 )

             from .source_transformation.pre_quantization import (
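An illustrative configuration for the QAT-plus-LoRA branch above. The field paths come from the diff; the concrete values are examples only.

# use_lora doubles as the enable flag and the LoRA rank (a falsy value disables it).
llm_config.quantization.use_qat = True
llm_config.base.use_lora = 16                 # requested LoRA rank
# The assert above then requires the checkpoint to agree:
#   model_args.lora_args["rank"] == 16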
@@ -226,16 +228,16 @@ def __init__(self, **kwargs):

             sanitize_checkpoint_from_pre_quantization(checkpoint)

-        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+        if self.llm_config and self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink

-            attention_sink_params = self.args.use_attention_sink.split(",")
+            attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
             assert len(attention_sink_params) == 3
             sink_size = int(attention_sink_params[0])
             window_size = int(attention_sink_params[1])
             eviction_batch_size = int(attention_sink_params[2])

-            assert self.args.max_context_length == sink_size + window_size
+            assert self.llm_config.export.max_context_length == sink_size + window_size

             self.model_ = enable_attention_sink(
                 module=self.model_,
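A sketch of the attention-sink settings parsed above. The field paths come from the diff; the numbers are examples chosen to satisfy the assert.

# "<sink_size>,<window_size>,<eviction_batch_size>"
llm_config.model.use_attention_sink = "4,2044,1024"
llm_config.export.max_context_length = 2048   # must equal sink_size + window_size = 4 + 2044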
@@ -326,20 +328,19 @@ def get_example_inputs_kvcache_sdpa(self):
         )

     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-        assert self.args.preq_mode in [
+        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-        assert hasattr(
-            self.args, "preq_group_size"
-        ), "preq_group_size must be specified"
-        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
+        assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
+        assert self.llm_config.model.dtype_override, "dtype_override must be specified"
+
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )

-        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+        assert self.llm_config.base.preq_group_size == model_args.quantization_args["group_size"]

         mapping = {
             "fp32": torch.float32,
@@ -348,28 +349,28 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }

         # Transform the output layer first if needed.
-        if self.args.preq_mode == "8da4w_output_8da8w":
+        if self.llm_config.base.preq_mode == "8da4w_output_8da8w":
             from .source_transformation.pre_quantization import (
                 transform_output_linear_for_pre_quantization,
             )

             self.model_ = transform_output_linear_for_pre_quantization(
                 module=self.model_,
                 checkpoint=checkpoint,
-                dtype=mapping[self.args.dtype_override],
+                dtype=mapping[self.llm_config.model.dtype_override],
             )

         self.model_ = transform_linear_for_pre_quantization(
             self.model_,
             checkpoint,
-            self.args.preq_group_size,
-            mapping[self.args.dtype_override],
+            self.llm_config.base.preq_group_size,
+            mapping[self.llm_config.model.dtype_override],
         )

         embedding_bit_width, embedding_group_size = None, None
-        if hasattr(self.args, "preq_embedding_quantize"):
+        if self.llm_config.base.preq_embedding_quantize:
             embedding_bit_width, embedding_group_size = (
-                self.args.preq_embedding_quantize.split(",")
+                self.llm_config.base.preq_embedding_quantize.split(",")
             )
             from .source_transformation.pre_quantization import (
                 transform_embedding_for_pre_quantization,
@@ -387,7 +388,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
             self.model_ = transform_embedding_for_pre_quantization(
                 self.model_,
                 checkpoint,
-                mapping[self.args.dtype_override],
+                mapping[self.llm_config.model.dtype_override],
                 int(embedding_bit_width),
                 embedding_group_size,
             )
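Finally, an illustrative set of pre-quantization fields consumed by `_transform_for_pre_quantization`. The field paths and the "fp32" mapping key come from the diff; the concrete values are examples only.

llm_config.base.preq_mode = "8da4w"               # or "8da4w_output_8da8w"
llm_config.base.preq_group_size = 32              # must equal model_args.quantization_args["group_size"]
llm_config.model.dtype_override = "fp32"          # key into the dtype mapping, e.g. torch.float32
llm_config.base.preq_embedding_quantize = "8,32"  # "<bit_width>,<group_size>"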