diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py
index 1031768ec0b..69065341701 100644
--- a/backends/arm/test/models/test_llama.py
+++ b/backends/arm/test/models/test_llama.py
@@ -22,7 +22,9 @@
     TosaPipelineMI,
 )
 
-from executorch.examples.models.llama.config.llm_config_utils import convert_args_to_llm_config
+from executorch.examples.models.llama.config.llm_config_utils import (
+    convert_args_to_llm_config,
+)
 from executorch.examples.models.llama.export_llama_lib import (
     build_args_parser,
     get_llama_model,
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 7933fe9ed3c..22233ac9c12 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -157,7 +157,8 @@ def build_model(
     argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    return export_llama(args)
+    llm_config = convert_args_to_llm_config(args)
+    return export_llama(llm_config)
 
 
 def parse_list_of_ints(s):
@@ -579,15 +580,10 @@ def export_llama(
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
-        args = export_options
         llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
         llm_config = export_options
-        # Create an args object for backward compatibility during transition
-        args = argparse.Namespace()
-        for key, value in llm_config.items():
-            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
@@ -626,7 +622,7 @@ def export_llama(
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(llm_config.debug.profile_path):
-                builder = _export_llama(llm_config, args)
+                builder = _export_llama(llm_config)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -637,14 +633,14 @@ def export_llama(
             )
             return ""
     else:
-        builder = _export_llama(llm_config, args)
+        builder = _export_llama(llm_config)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -672,7 +668,7 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
         dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        llm_config.base.model_class,
+        llm_config,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
@@ -695,7 +691,6 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         use_qnn=llm_config.backend.qnn.enabled,
         export_only=llm_config.export.export_only,
-        args=args,
     )
 
     # At this point, the model is loaded in the default fp32.
@@ -1054,7 +1049,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     return builder
 
 
-def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     _validate_args(llm_config)
 
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
@@ -1066,7 +1061,7 @@ def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
     additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(llm_config, args).export()
+    builder_exported = _prepare_for_llama_export(llm_config).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
@@ -1174,7 +1169,7 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
-    modelname: str = "llama3",
+    llm_config: LlmConfig,
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
@@ -1198,8 +1193,6 @@ def _load_llama_model(
     dtype_override: Optional[DType] = None,
     use_qnn: bool = False,
     export_only: bool = False,
-    args,
-    llm_config: Optional[LlmConfig] = None,
 ) -> "LLMEdgeManager":
     """
     A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
@@ -1208,6 +1201,7 @@ def _load_llama_model(
 
     An instance of LLMEdgeManager which contains the eager mode model.
     """
+    modelname = llm_config.base.model_class
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1220,13 +1214,11 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            model_args={"llm_config": llm_config},
+            llm_config=llm_config,
         )
     )
 
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 7a3e5b80e79..5d5d31f18be 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -16,6 +16,7 @@
     get_default_model_resource_dir,
 )
 
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
@@ -36,11 +37,11 @@ def convert_to_llama_checkpoint(**kwargs):
 
 
 class Llama2Model(EagerModelBase):
-    def __init__(self, llm_config):
+    def __init__(self, llm_config: LlmConfig):
         resource_dir = get_default_model_resource_dir(__file__)
 
         self.llm_config = llm_config
-        
+
         # Use single checkpoint file.
         checkpoint_path = self.llm_config.base.checkpoint
         # Check if checkpoint_dir was provided for a sharded checkpoint.
@@ -48,7 +49,7 @@ def __init__(self, llm_config):
 
         # Params file.
         params_path = self.llm_config.base.params
-        
+
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
         self.generate_full_logits = self.llm_config.debug.generate_full_logits
@@ -101,7 +102,7 @@ def __init__(self, llm_config):
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
         # If given checkpoint is fairseq, convert to llama checkpoint.
-        fairseq2_checkpoint = kwargs.get("fairseq2", False)
+        fairseq2_checkpoint = llm_config.base.fairseq2
         if fairseq2_checkpoint:
             print("Using fairseq2 checkpoint")
             checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
@@ -337,12 +338,15 @@ def _transform_for_pre_quantization(self, checkpoint, model_args):
         ], f"Quantization mode {self.llm_config.base.preq_mode} is not compatible with SpinQuant."
         assert self.llm_config.base.preq_group_size, "preq_group_size must be specified"
         assert self.llm_config.model.dtype_override, "dtype_override must be specified"
-        
+
         from .source_transformation.pre_quantization import (
             transform_linear_for_pre_quantization,
         )
 
-        assert self.llm_config.base.preq_group_size == model_args.quantization_args["group_size"]
+        assert (
+            self.llm_config.base.preq_group_size
+            == model_args.quantization_args["group_size"]
+        )
 
         mapping = {
             "fp32": torch.float32,
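
Usage sketch (illustrative, not part of the patch): with args removed from the internal helpers, a caller parses the legacy CLI flags, converts them to an LlmConfig with convert_args_to_llm_config, and passes that to export_llama, mirroring the updated build_model above. The flags shown are the same ones build_model assembles in argString; the checkpoint and params paths are placeholders.

    import shlex

    from executorch.examples.models.llama.config.llm_config_utils import (
        convert_args_to_llm_config,
    )
    from executorch.examples.models.llama.export_llama_lib import (
        build_args_parser,
        export_llama,
    )

    # Parse the legacy CLI flags into an argparse.Namespace.
    args = build_args_parser().parse_args(
        shlex.split(
            "--model llama3 --checkpoint /path/to/checkpoint.pth --params /path/to/params.json"
        )
    )

    # Convert the Namespace into an LlmConfig and run the export.
    llm_config = convert_args_to_llm_config(args)
    pte_filename = export_llama(llm_config)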