|  | 
| 4 | 4 | # This source code is licensed under the BSD-style license found in the | 
| 5 | 5 | # LICENSE file in the root directory of this source tree. | 
| 6 | 6 | 
 | 
| 7 |  | -# Example script for exporting Llama2 to flatbuffer | 
| 8 |  | - | 
| 9 |  | -import logging | 
| 10 |  | - | 
| 11 | 7 | # force=True to ensure logging while in debugger. Set up logger before any | 
| 12 | 8 | # other imports. | 
|  | 9 | +import logging | 
|  | 10 | + | 
| 13 | 11 | FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" | 
| 14 | 12 | logging.basicConfig(level=logging.INFO, format=FORMAT, force=True) | 
| 15 | 13 | 
 | 
|  | 14 | +import argparse | 
|  | 15 | +import runpy | 
| 16 | 16 | import sys | 
| 17 | 17 | 
 | 
| 18 | 18 | import torch | 
| 19 | 19 | 
 | 
| 20 |  | -from .export_llama_lib import build_args_parser, export_llama | 
| 21 |  | - | 
| 22 | 20 | sys.setrecursionlimit(4096) | 
| 23 | 21 | 
 | 
| 24 | 22 | 
 | 
|  | 23 | +def parse_hydra_arg(): | 
|  | 24 | +    """First parse out the arg for whether to use Hydra or the old CLI.""" | 
|  | 25 | +    parser = argparse.ArgumentParser(add_help=True) | 
|  | 26 | +    parser.add_argument("--hydra", action="store_true") | 
|  | 27 | +    args, remaining = parser.parse_known_args() | 
|  | 28 | +    return args.hydra, remaining | 
|  | 29 | + | 
|  | 30 | + | 
| 25 | 31 | def main() -> None: | 
| 26 | 32 |     seed = 42 | 
| 27 | 33 |     torch.manual_seed(seed) | 
| 28 |  | -    parser = build_args_parser() | 
| 29 |  | -    args = parser.parse_args() | 
| 30 |  | -    export_llama(args) | 
|  | 34 | + | 
|  | 35 | +    use_hydra, remaining_args = parse_hydra_arg() | 
|  | 36 | +    if use_hydra: | 
|  | 37 | +        # The import runs the main function of export_llama_hydra with the remaining args | 
|  | 38 | +        # under the Hydra framework. | 
|  | 39 | +        sys.argv = [arg for arg in sys.argv if arg != "--hydra"] | 
|  | 40 | +        print(f"running with {sys.argv}") | 
|  | 41 | +        runpy.run_module( | 
|  | 42 | +            "executorch.examples.models.llama.export_llama_hydra", run_name="__main__" | 
|  | 43 | +        ) | 
|  | 44 | +    else: | 
|  | 45 | +        # Use the legacy version of the export_llama script which uses argsparse. | 
|  | 46 | +        from executorch.examples.models.llama.export_llama_args import ( | 
|  | 47 | +            main as export_llama_args_main, | 
|  | 48 | +        ) | 
|  | 49 | + | 
|  | 50 | +        export_llama_args_main(remaining_args) | 
| 31 | 51 | 
 | 
| 32 | 52 | 
 | 
| 33 | 53 | if __name__ == "__main__": | 
|  | 
0 commit comments