Skip to content

Commit 7507002

Browse files
committed
Modularize and update base eager runner
1 parent 9bd405f commit 7507002

File tree

2 files changed

+10
-40
lines changed

2 files changed

+10
-40
lines changed

examples/models/llama/runner/eager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import argparse
88
import json
9-
from typing import Optional
9+
from typing import Optional, Type
1010

1111
import torch
1212

@@ -77,11 +77,10 @@ def build_args_parser() -> argparse.ArgumentParser:
7777
return parser
7878

7979

80-
def main() -> None:
80+
def execute_runner(runner_class: Type[LlamaRunner]) -> None:
8181
parser = build_args_parser()
8282
args = parser.parse_args()
83-
84-
runner = EagerLlamaRunner(args)
83+
runner = runner_class(args)
8584
generated_tokens = (
8685
runner.chat_completion(temperature=args.temperature)
8786
if args.chat
@@ -95,5 +94,9 @@ def main() -> None:
9594
print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
9695

9796

97+
def main() -> None:
98+
execute_runner(EagerLlamaRunner)
99+
100+
98101
if __name__ == "__main__":
99102
main() # pragma: no cover

examples/models/llama3_2_vision/runner/eager.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import argparse
87
import json
98
from typing import Optional
109

1110
import torch
1211

13-
from executorch.examples.models.llama.export_llama_lib import (
14-
_prepare_for_llama_export,
15-
build_args_parser as _build_args_parser,
16-
)
12+
from executorch.examples.models.llama.export_llama_lib import _prepare_for_llama_export
13+
from executorch.examples.models.llama.runner.eager import execute_runner
1714
from executorch.examples.models.llama3_2_vision.runner.generation import (
1815
TorchTuneLlamaRunner,
1916
)
@@ -48,38 +45,8 @@ def forward(
4845
return self.model.forward(tokens=tokens, input_pos=input_pos, mask=mask)
4946

5047

51-
def build_args_parser() -> argparse.ArgumentParser:
52-
parser = _build_args_parser()
53-
54-
parser.add_argument(
55-
"--prompt",
56-
type=str,
57-
default="Hello",
58-
)
59-
60-
parser.add_argument(
61-
"--temperature",
62-
type=float,
63-
default=0,
64-
)
65-
66-
return parser
67-
68-
6948
def main() -> None:
70-
parser = build_args_parser()
71-
args = parser.parse_args()
72-
73-
runner = EagerLlamaRunner(args)
74-
result = runner.text_completion(
75-
prompt=args.prompt,
76-
temperature=args.temperature,
77-
)
78-
print(
79-
"Response: \n{response}\n Tokens:\n {tokens}".format(
80-
response=result["generation"], tokens=result["tokens"]
81-
)
82-
)
49+
execute_runner(EagerLlamaRunner)
8350

8451

8552
if __name__ == "__main__":

0 commit comments

Comments (0)