4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
7 | | -import argparse |
8 | 7 | import json |
9 | 8 | from typing import Optional |
10 | 9 |
11 | 10 | import torch |
12 | 11 |
13 | | -from executorch.examples.models.llama.export_llama_lib import ( |
14 | | - _prepare_for_llama_export, |
15 | | - build_args_parser as _build_args_parser, |
16 | | -) |
| 12 | +from executorch.examples.models.llama.export_llama_lib import _prepare_for_llama_export |
| 13 | +from executorch.examples.models.llama.runner.eager import execute_runner |
17 | 14 | from executorch.examples.models.llama3_2_vision.runner.generation import ( |
18 | 15 | TorchTuneLlamaRunner, |
19 | 16 | ) |
@@ -48,38 +45,8 @@ def forward( |
48 | 45 | return self.model.forward(tokens=tokens, input_pos=input_pos, mask=mask) |
49 | 46 |
50 | 47 |
51 | | -def build_args_parser() -> argparse.ArgumentParser: |
52 | | - parser = _build_args_parser() |
53 | | - |
54 | | - parser.add_argument( |
55 | | - "--prompt", |
56 | | - type=str, |
57 | | - default="Hello", |
58 | | - ) |
59 | | - |
60 | | - parser.add_argument( |
61 | | - "--temperature", |
62 | | - type=float, |
63 | | - default=0, |
64 | | - ) |
65 | | - |
66 | | - return parser |
67 | | - |
68 | | - |
69 | 48 | def main() -> None: |
70 | | - parser = build_args_parser() |
71 | | - args = parser.parse_args() |
72 | | - |
73 | | - runner = EagerLlamaRunner(args) |
74 | | - result = runner.text_completion( |
75 | | - prompt=args.prompt, |
76 | | - temperature=args.temperature, |
77 | | - ) |
78 | | - print( |
79 | | - "Response: \n{response}\n Tokens:\n {tokens}".format( |
80 | | - response=result["generation"], tokens=result["tokens"] |
81 | | - ) |
82 | | - ) |
| 49 | + execute_runner(EagerLlamaRunner) |
83 | 50 |
84 | 51 |
85 | 52 | if __name__ == "__main__": |
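
For readers skimming the diff: the new `execute_runner` entry point centralizes the flag parsing and text-completion loop that this runner previously duplicated in its own `build_args_parser`/`main`. Below is a minimal sketch of what such a helper could look like, assuming it mirrors the logic removed above; the actual implementation in `examples/models/llama/runner/eager.py` may differ (e.g. it likely reuses the shared export arg parser and supports more flags).

```python
# Hypothetical sketch of a shared execute_runner helper; it mirrors the
# per-runner main() removed in this diff. The real helper in
# examples/models/llama/runner/eager.py may parse additional flags.
import argparse
from typing import Type


def execute_runner(runner_class: Type) -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="Hello")
    parser.add_argument("--temperature", type=float, default=0.0)
    args = parser.parse_args()

    # Instantiate the concrete runner (e.g. EagerLlamaRunner) and run a
    # single text completion, printing the response and its token ids.
    runner = runner_class(args)
    result = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(
        "Response: \n{response}\n Tokens:\n {tokens}".format(
            response=result["generation"], tokens=result["tokens"]
        )
    )
```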