diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 0a4ecb73199..c1e69a47d3d 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1226,11 +1226,6 @@ def _load_llama_model(
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            checkpoint=checkpoint,
-            checkpoint_dir=checkpoint_dir,
-            params=params_path,
-            fairseq2=weight_type == WeightType.FAIRSEQ2,
-            dtype=torch_dtype,
             llm_config=llm_config,
         )
     )
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 4d441a5a32e..d11cf06352b 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -47,30 +47,18 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

-        self.llm_config = kwargs.get("llm_config", None)
+        self.llm_config = kwargs.get("llm_config")
+        assert self.llm_config is not None, "llm_config must be provided"

-        # Set all parameters from llm_config if available, otherwise use kwargs as fallback
-        if self.llm_config:
-            self.use_kv_cache = self.llm_config.model.use_kv_cache
-            self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
-            self.generate_full_logits = self.llm_config.debug.generate_full_logits
-            self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
-            self.input_prune_map_path = self.llm_config.model.input_prune_map
-            self.output_prune_map_path = self.llm_config.model.output_prune_map
-            self.max_seq_len = self.llm_config.export.max_seq_length
-            self.max_context_len = self.llm_config.export.max_context_length
-            self.verbose = self.llm_config.debug.verbose
-        else:
-            # Fallback to kwargs for backward compatibility
-            self.use_kv_cache = kwargs.get("use_kv_cache", False)
-            self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
-            self.generate_full_logits = kwargs.get("generate_full_logits", False)
-            self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
-            self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
-            self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-            self.max_seq_len = kwargs.get("max_seq_len", 128)
-            self.max_context_len = kwargs.get("max_context_len", 128)
-            self.verbose = kwargs.get("verbose", False)
+        self.use_kv_cache = self.llm_config.model.use_kv_cache
+        self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
+        self.generate_full_logits = self.llm_config.debug.generate_full_logits
+        self.enable_dynamic_shape = self.llm_config.model.enable_dynamic_shape
+        self.input_prune_map_path = self.llm_config.model.input_prune_map
+        self.output_prune_map_path = self.llm_config.model.output_prune_map
+        self.max_seq_len = self.llm_config.export.max_seq_length
+        self.max_context_len = self.llm_config.export.max_context_length
+        self.verbose = self.llm_config.debug.verbose

         assert (
             self.max_context_len >= self.max_seq_len
@@ -173,7 +161,7 @@ def __init__(self, **kwargs):

         if model_args.use_scaled_rope:
             # Older models don't have use_scaled_rope configuration
-            model_name = str(self.llm_config.base.model_class) if self.llm_config else "llama3"
+            model_name = str(self.llm_config.base.model_class)
             assert model_name not in ["llama2", "stories110m"]

             # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
@@ -212,7 +200,7 @@ def __init__(self, **kwargs):
                 self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                     self.model_
                 )
-            elif self.llm_config and self.llm_config.quantization.use_spin_quant:
+            elif self.llm_config.quantization.use_spin_quant:
                 print("Using SPIN quantization.")
                 self._transform_for_pre_quantization(checkpoint, model_args)

@@ -221,7 +209,7 @@ def __init__(self, **kwargs):
                 )

                 sanitize_checkpoint_from_pre_quantization(checkpoint)
-            elif self.llm_config and self.llm_config.quantization.use_qat:
+            elif self.llm_config.quantization.use_qat:
                 print("Using QAT quantization.")
                 self._transform_for_pre_quantization(checkpoint, model_args)
                 if self.llm_config.base.use_lora:
@@ -243,7 +231,7 @@ def __init__(self, **kwargs):

             sanitize_checkpoint_from_pre_quantization(checkpoint)

-        if self.llm_config and self.llm_config.model.use_attention_sink:
+        if self.llm_config.model.use_attention_sink:
             from .source_transformation.attention_sink import enable_attention_sink

             attention_sink_params = self.llm_config.model.use_attention_sink.split(",")
@@ -343,7 +331,7 @@ def get_example_inputs_kvcache_sdpa(self):
         )

     def _transform_for_pre_quantization(self, checkpoint, model_args):
-        assert self.llm_config and self.llm_config.base.preq_mode, "preq_mode must be specified"
+        assert self.llm_config.base.preq_mode, "preq_mode must be specified"
         assert self.llm_config.base.preq_mode in [
             "8da4w",
             "8da4w_output_8da8w",
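With the kwargs fallback removed, llm_config becomes the single source of truth: callers populate it and pass it through EagerModelFactory.create_model instead of individual keyword arguments. A minimal sketch of the resulting call pattern is below; only the attribute paths read by model.py in this diff (model.use_kv_cache, export.max_seq_length, debug.verbose, ...) come from the change, while the import paths, default-constructibility of LlmConfig, and the module/class names used here are assumptions for illustration.

    # Hypothetical usage sketch -- import paths and LlmConfig construction are
    # assumed; the attribute paths mirror what Llama2Model.__init__ now reads.
    from executorch.examples.models.model_factory import EagerModelFactory  # assumed path
    from executorch.extension.llm.export.config.llm_config import LlmConfig  # assumed path

    llm_config = LlmConfig()                        # assumed default-constructible
    llm_config.model.use_kv_cache = True            # was kwargs["use_kv_cache"]
    llm_config.model.use_sdpa_with_kv_cache = True  # was kwargs["use_sdpa_with_kv_cache"]
    llm_config.export.max_seq_length = 128          # was kwargs["max_seq_len"]
    llm_config.export.max_context_length = 128      # was kwargs["max_context_len"]
    llm_config.debug.verbose = True                 # was kwargs["verbose"]

    # Model-specific kwargs such as checkpoint/params no longer travel here;
    # the factory forwards only llm_config to the model's constructor.
    model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
        EagerModelFactory.create_model(
            "llama",        # module_name (illustrative)
            "Llama2Model",  # model_class_name (illustrative)
            llm_config=llm_config,
        )
    )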