From 673e39fe3f48a873377a4fd1d1c4b502051652a5 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Tue, 20 May 2025 17:43:19 -0700
Subject: [PATCH] Introduce Hydra framework with backwards compatibility

[ghstack-poisoned]
---
 .../models/llama/config/llm_config_utils.py   | 22 +++++++++++
 examples/models/llama/export_llama.py         | 38 ++++++++++++++-----
 examples/models/llama/export_llama_args.py    | 21 ++++++++++
 examples/models/llama/export_llama_hydra.py   | 27 +++++++++++++
 examples/models/llama/export_llama_lib.py     | 24 +++++++++++-
 examples/models/llama/install_requirements.sh |  2 +-
 6 files changed, 123 insertions(+), 11 deletions(-)
 create mode 100644 examples/models/llama/config/llm_config_utils.py
 create mode 100644 examples/models/llama/export_llama_args.py
 create mode 100644 examples/models/llama/export_llama_hydra.py

diff --git a/examples/models/llama/config/llm_config_utils.py b/examples/models/llama/config/llm_config_utils.py
new file mode 100644
index 00000000000..9c5178d26cb
--- /dev/null
+++ b/examples/models/llama/config/llm_config_utils.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+
+from executorch.examples.models.llama.config.llm_config import LlmConfig
+
+
+def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
+    """
+    For backwards compatibility, convert legacy argparse CLI args into an
+    LlmConfig, which the LLM export process consumes.
+    """
+    llm_config = LlmConfig()
+
+    # TODO: conversion code.
+
+    return llm_config
diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py
index e25a8a007eb..63e76e28ba9 100644
--- a/examples/models/llama/export_llama.py
+++ b/examples/models/llama/export_llama.py
@@ -4,30 +4,50 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# Example script for exporting Llama2 to flatbuffer
-
-import logging
-
 # force=True to ensure logging while in debugger. Set up logger before any
 # other imports.
+import logging
+
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT, force=True)
 
+import argparse
+import runpy
 import sys
 
 import torch
 
-from .export_llama_lib import build_args_parser, export_llama
-
 sys.setrecursionlimit(4096)
 
 
+def parse_hydra_arg():
+    """First parse out the arg for whether to use Hydra or the old CLI."""
+    parser = argparse.ArgumentParser(add_help=True)
+    parser.add_argument("--hydra", action="store_true")
+    args, remaining = parser.parse_known_args()
+    return args.hydra, remaining
+
+
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    parser = build_args_parser()
-    args = parser.parse_args()
-    export_llama(args)
+
+    use_hydra, remaining_args = parse_hydra_arg()
+    if use_hydra:
+        # Run the main function of export_llama_hydra with the remaining
+        # args under the Hydra framework.
+        sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
+        logging.info(f"Running with {sys.argv}")
+        runpy.run_module(
+            "executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
+        )
+    else:
+        # Use the legacy version of the export_llama script, which uses argparse.
+        from executorch.examples.models.llama.export_llama_args import (
+            main as export_llama_args_main,
+        )
+
+        export_llama_args_main(remaining_args)
 
 
 if __name__ == "__main__":
diff --git a/examples/models/llama/export_llama_args.py b/examples/models/llama/export_llama_args.py
new file mode 100644
index 00000000000..7a176d9b7d0
--- /dev/null
+++ b/examples/models/llama/export_llama_args.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run export_llama with the legacy argparse setup.
+"""
+
+from .export_llama_lib import build_args_parser, export_llama
+
+
+def main(args=None) -> None:
+    parser = build_args_parser()
+    args = parser.parse_args(args)
+    export_llama(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/models/llama/export_llama_hydra.py b/examples/models/llama/export_llama_hydra.py
new file mode 100644
index 00000000000..73eca7e2a5a
--- /dev/null
+++ b/examples/models/llama/export_llama_hydra.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run export_llama using the new Hydra CLI.
+"""
+
+import hydra
+
+from executorch.examples.models.llama.config.llm_config import LlmConfig
+from executorch.examples.models.llama.export_llama_lib import export_llama
+from hydra.core.config_store import ConfigStore
+
+cs = ConfigStore.instance()
+cs.store(name="llm_config", node=LlmConfig)
+
+
+@hydra.main(version_base=None, config_name="llm_config")
+def main(llm_config: LlmConfig) -> None:
+    export_llama(llm_config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index c430da78832..ec6977b8d1f 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -28,6 +28,10 @@
 
 from executorch.devtools.backend_debug import print_delegation_info
 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
+
+from executorch.examples.models.llama.config.llm_config_utils import (
+    convert_args_to_llm_config,
+)
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
@@ -51,6 +55,7 @@
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
+from omegaconf.dictconfig import DictConfig
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -568,7 +573,24 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(args) -> str:
+def export_llama(
+    export_options: Union[argparse.Namespace, DictConfig],
+) -> str:
+    if isinstance(export_options, argparse.Namespace):
+        # Legacy CLI.
+        args = export_options
+        llm_config = convert_args_to_llm_config(export_options)  # noqa: F841
+    elif isinstance(export_options, DictConfig):
+        # Hydra CLI.
+        llm_config = export_options  # noqa: F841
+    else:
+        raise ValueError(
+            "Input to export_llama must be either of type argparse.Namespace or DictConfig"
+        )
+
+    # TODO: refactor the rest of export_llama to use llm_config instead of
+    # args, which the Hydra path does not define.
+
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
     if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh
index b9e0f9210c5..580a152a322 100755
--- a/examples/models/llama/install_requirements.sh
+++ b/examples/models/llama/install_requirements.sh
@@ -10,7 +10,7 @@
 # Install tokenizers for hf .json tokenizer.
 # Install snakeviz for cProfile flamegraph
 # Install lm-eval for Model Evaluation with lm-evaluation-harness.
-pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
+pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
 
 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py
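
Note on the conversion stub: convert_args_to_llm_config() is intentionally left
as a TODO in this patch. A minimal sketch of the intended mapping might look
like the following; the LlmConfig field names used here (base.checkpoint,
base.params, model.use_kv_cache) are illustrative assumptions only, since the
LlmConfig schema is not part of this diff:

    import argparse

    from executorch.examples.models.llama.config.llm_config import LlmConfig


    def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
        llm_config = LlmConfig()

        # Copy each legacy CLI arg onto the corresponding config field,
        # skipping anything the user did not pass. The field names below
        # are hypothetical placeholders, not confirmed by this patch.
        if getattr(args, "checkpoint", None) is not None:
            llm_config.base.checkpoint = args.checkpoint
        if getattr(args, "params", None) is not None:
            llm_config.base.params = args.params
        if getattr(args, "use_kv_cache", None) is not None:
            llm_config.model.use_kv_cache = args.use_kv_cache

        return llm_config

Usage after this patch: the legacy argparse CLI remains the default
(python examples/models/llama/export_llama.py <legacy flags>), while passing
--hydra opts into Hydra's key=value override syntax, e.g.
export_llama.py --hydra base.checkpoint=/path/to/model.pth; the override key
again assumes the hypothetical schema sketched above.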