|
4 | 4 | # This source code is licensed under the BSD-style license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 |
| -# Example script for exporting Llama2 to flatbuffer |
8 |
| - |
9 |
| -import logging |
10 |
| - |
11 | 7 | # force=True to ensure logging while in debugger. Set up logger before any
|
12 | 8 | # other imports.
|
| 9 | +import logging |
| 10 | + |
13 | 11 | FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
|
14 | 12 | logging.basicConfig(level=logging.INFO, format=FORMAT, force=True)
|
15 | 13 |
|
| 14 | +import argparse |
| 15 | +import runpy |
16 | 16 | import sys
|
17 | 17 |
|
18 | 18 | import torch
|
19 | 19 |
|
20 |
| -from .export_llama_lib import build_args_parser, export_llama |
21 |
| - |
22 | 20 | sys.setrecursionlimit(4096)
|
23 | 21 |
|
24 | 22 |
|
| 23 | +def parse_hydra_arg(): |
| 24 | + """First parse out the arg for whether to use Hydra or the old CLI.""" |
| 25 | + parser = argparse.ArgumentParser(add_help=True) |
| 26 | + parser.add_argument("--hydra", action="store_true") |
| 27 | + args, remaining = parser.parse_known_args() |
| 28 | + return args.hydra, remaining |
| 29 | + |
| 30 | + |
25 | 31 | def main() -> None:
|
26 | 32 | seed = 42
|
27 | 33 | torch.manual_seed(seed)
|
28 |
| - parser = build_args_parser() |
29 |
| - args = parser.parse_args() |
30 |
| - export_llama(args) |
| 34 | + |
| 35 | + use_hydra, remaining_args = parse_hydra_arg() |
| 36 | + if use_hydra: |
| 37 | + # The import runs the main function of export_llama_hydra with the remaining args |
| 38 | + # under the Hydra framework. |
| 39 | + sys.argv = [arg for arg in sys.argv if arg != "--hydra"] |
| 40 | + print(f"running with {sys.argv}") |
| 41 | + runpy.run_module( |
| 42 | + "executorch.examples.models.llama.export_llama_hydra", run_name="__main__" |
| 43 | + ) |
| 44 | + else: |
| 45 | + # Use the legacy version of the export_llama script which uses argsparse. |
| 46 | + from executorch.examples.models.llama.export_llama_args import ( |
| 47 | + main as export_llama_args_main, |
| 48 | + ) |
| 49 | + |
| 50 | + export_llama_args_main(remaining_args) |
31 | 51 |
|
32 | 52 |
|
33 | 53 | if __name__ == "__main__":
|
|
0 commit comments