From 4e66fb5aade074b617398949d7c443a773fd496e Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Tue, 27 May 2025 16:41:45 -0700
Subject: [PATCH] refactor: Replace self.args with LlmConfig in model.py and
 export_llama_lib.py

[ghstack-poisoned]
---
 examples/models/llama/export_llama_lib.py |  2 +-
 examples/models/llama/model.py            | 53 ++++++++++++-----------
 2 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 4f5631b6159..28a2ae3debb 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1237,7 +1237,7 @@ def _load_llama_model(
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
             dtype=torch_dtype,
-            args=args,
+            llm_config=llm_config,
         )
     )
 
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index d6400c29db8..ac0119bcbe5 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -55,7 +55,7 @@ def __init__(self, **kwargs):
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
         self.max_context_len = kwargs.get("max_context_len", 128)
-        self.args = kwargs.get("args", None)
+        self.llm_config = kwargs.get("llm_config", None)
 
         assert (
             self.max_context_len >= self.max_seq_len
@@ -158,10 +158,11 @@ def __init__(self, **kwargs):
 
         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            assert self.args.model not in ["llama2", "stories110m"]
+            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            assert model_name not in ["llama2", "stories110m"]
 
             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
-            if self.args.model not in ["llama3", "llama3_1"]:
+            if model_name not in ["llama3", "llama3_1"]:
                 model_args.rope_scale_factor = 32
 
         if kwargs.get("verbose", False):
@@ -196,7 +197,7 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
-        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
+        elif self.llm_config and self.llm_config.quantization.use_spin_quant:
             print("Using SPIN quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
 
@@ -205,11 +206,12 @@ def __init__(self, **kwargs):
             )
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
-        elif hasattr(self.args, "use_qat") and self.args.use_qat:
+        elif self.llm_config and self.llm_config.quantization.use_qat:
             print("Using QAT quantization.")
             self._transform_for_pre_quantization(checkpoint, model_args)
-            if hasattr(self.args, "use_lora") and self.args.use_lora:
-                assert model_args.lora_args["rank"] == self.args.use_lora
+            if self.llm_config.base.use_lora:
+                lora_rank = self.llm_config.base.use_lora
+                assert model_args.lora_args["rank"] == lora_rank
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
@@ -217,7 +219,7 @@ def __init__(self, **kwargs):
                 self.model_ = transform_linear_for_lora_after_quantization(
                     self.model_,
                     checkpoint,
-                    self.args.use_lora,
+                    lora_rank,
                 )
 
             from .source_transformation.pre_quantization import (
@@ -226,16 +228,16 @@ def __init__(self, **kwargs):
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
 
-        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+        if self.llm_config and self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink
 
-            attention_sink_params = self.args.use_attention_sink.split(",")
+            attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
             assert len(attention_sink_params) == 3
             sink_size = int(attention_sink_params[0])
             window_size = int(attention_sink_params[1])
             eviction_batch_size = int(attention_sink_params[2])
 
-            assert self.args.max_context_length == sink_size + window_size
+            assert self.llm_config.export.max_context_length == sink_size + window_size
 
             self.model_ = enable_attention_sink(
                 module=self.model_,
@@ -326,20 +328,19 @@ def get_example_inputs_kvcache_sdpa(self):
         )
 
     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
-        assert self.args.preq_mode in [
+        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",
-        ], f"Quantization mode {self.args.preq_mode} is not compatible with SpinQuant."
-        assert hasattr(
-            self.args, "preq_group_size"
-        ), "preq_group_size must be specified"
-        assert hasattr(self.args, "dtype_override"), "dtype_override must be specified"
+        ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
+        assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
+        assert self.llm_config.model.dtype_override, "dtype_override must be specified"
+
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )
 
-        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+        assert self.llm_config.base.preq_group_size == model_args.quantization_args["group_size"]
 
         mapping = {
             "fp32": torch.float32,
@@ -348,7 +349,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         }
 
         # Transform the output layer first if needed.
-        if self.args.preq_mode == "8da4w_output_8da8w":
+        if self.llm_config.base.preq_mode == "8da4w_output_8da8w":
             from .source_transformation.pre_quantization import (
                 transform_output_linear_for_pre_quantization,
             )
@@ -356,20 +357,20 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
             self.model_ = transform_output_linear_for_pre_quantization(
                 module=self.model_,
                 checkpoint=checkpoint,
-                dtype=mapping[self.args.dtype_override],
+                dtype=mapping[self.llm_config.model.dtype_override],
             )
 
         self.model_ = transform_linear_for_pre_quantization(
             self.model_,
             checkpoint,
-            self.args.preq_group_size,
-            mapping[self.args.dtype_override],
+            self.llm_config.base.preq_group_size,
+            mapping[self.llm_config.model.dtype_override],
         )
 
         embedding_bit_width, embedding_group_size = None, None
-        if hasattr(self.args, "preq_embedding_quantize"):
+        if self.llm_config.base.preq_embedding_quantize:
             embedding_bit_width, embedding_group_size = (
-                self.args.preq_embedding_quantize.split(",")
+                self.llm_config.base.preq_embedding_quantize.split(",")
             )
             from .source_transformation.pre_quantization import (
                 transform_embedding_for_pre_quantization,
@@ -387,7 +388,7 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         self.model_ = transform_embedding_for_pre_quantization(
             self.model_,
             checkpoint,
-            mapping[self.args.dtype_override],
+            mapping[self.llm_config.model.dtype_override],
             int(embedding_bit_width),
             embedding_group_size,
         )
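--
For reference, the hunks above read a handful of nested fields off llm_config: base.model_class, base.use_lora, base.preq_mode, base.preq_group_size, base.preq_embedding_quantize, model.dtype_override, model.use_attention_sink, quantization.use_spin_quant, quantization.use_qat, and export.max_context_length. The following is a minimal, hypothetical sketch of that shape, inferred only from the accesses in this diff; the class layout and defaults are placeholders, not the actual LlmConfig definition in the ExecuTorch tree.

# Hypothetical sketch of the nested config shape the patch reads from.
# Field names come from the accesses in this diff; defaults are illustrative.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BaseConfig:
    model_class: str = "llama3"                    # llm_config.base.model_class
    use_lora: int = 0                              # LoRA rank; falsy disables the LoRA path
    preq_mode: Optional[str] = None                # "8da4w" or "8da4w_output_8da8w"
    preq_group_size: int = 32                      # checked against model_args.quantization_args["group_size"]
    preq_embedding_quantize: Optional[str] = None  # "bitwidth,group_size" string


@dataclass
class ModelConfig:
    dtype_override: str = "fp32"                   # key into the "fp32"/... -> torch dtype mapping in model.py
    use_attention_sink: Optional[str] = None       # "sink_size,window_size,eviction_batch_size"


@dataclass
class QuantizationConfig:
    use_spin_quant: Optional[str] = None           # truthy enables the SpinQuant branch
    use_qat: bool = False                          # truthy enables the QAT branch


@dataclass
class ExportConfig:
    max_context_length: int = 128                  # asserted to equal sink_size + window_size


@dataclass
class LlmConfig:
    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
    export: ExportConfig = field(default_factory=ExportConfig)

With a shape like this, a caller populates the nested groups and forwards the object through the llm_config kwarg that replaces args, e.g.:

llm_config = LlmConfig()
llm_config.quantization.use_qat = True
llm_config.base.use_lora = 16          # rank; must match model_args.lora_args["rank"]
llm_config.model.dtype_override = "fp32"
# ... then passed through as llm_config=llm_config (see _load_llama_model above),
# and model.py retrieves it via kwargs.get("llm_config", None).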