Skip to content

Commit a93e716

Browse files
authored
Add exported program runner (#6969)
1 parent 711f1c2 commit a93e716

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import argparse
8+
import json
9+
from typing import Optional
10+
11+
import torch
12+
13+
from executorch.examples.models.llama.export_llama_lib import (
14+
build_args_parser as _build_args_parser,
15+
)
16+
from executorch.examples.models.llama3_2_vision.runner.generation import (
17+
TorchTuneLlamaRunner,
18+
)
19+
20+
21+
class ExportedLlamaRunner(TorchTuneLlamaRunner):
    """
    Runs a torch-exported .pt2 Llama.

    The runner is configured from parsed command-line arguments: the params
    JSON file supplies ``vocab_size`` for the base runner, and the exported
    program at ``args.pt2`` is loaded with ``torch.export.load`` and turned
    into a callable module stored on ``self.model``.
    """

    def __init__(self, args):
        """
        Initialize the base runner and load the exported model.

        Args:
            args: Parsed arguments providing ``params`` (path to a JSON file
                containing a ``vocab_size`` entry), ``tokenizer_path``,
                ``max_seq_length``, ``use_kv_cache`` and ``pt2`` (path to the
                exported program).

        Raises:
            OSError: If the params file cannot be opened.
            json.JSONDecodeError: If the params file is not valid JSON.
            KeyError: If the params JSON has no ``vocab_size`` entry.
        """
        # Decode explicitly as UTF-8 rather than relying on the platform's
        # locale-dependent default encoding.
        with open(args.params, "r", encoding="utf-8") as f:
            params = json.load(f)
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=params["vocab_size"],
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        print(f"Loading model from {args.pt2}")
        # torch.export.load returns an ExportedProgram; .module() yields the
        # callable used by forward().
        self.model = torch.export.load(args.pt2).module()
        print("Model loaded")

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        input_pos: Optional[torch.LongTensor] = None,
        mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Call the loaded exported module.

        When ``self.use_kv_cache`` is truthy, ``input_pos`` and ``mask`` are
        forwarded to the module as keyword arguments; otherwise only
        ``tokens`` is passed and the other two arguments are ignored.

        NOTE(review): ``self.use_kv_cache`` is not assigned in this class;
        presumably the base class sets it from the constructor argument --
        confirm.

        Args:
            tokens: Token tensor passed positionally to the module.
            input_pos: Forwarded as ``input_pos`` in the KV-cache path only.
            mask: Forwarded as ``mask`` in the KV-cache path only.

        Returns:
            Whatever the exported module returns for the given inputs.
        """
        if self.use_kv_cache:
            return self.model(tokens, input_pos=input_pos, mask=mask)
        else:
            return self.model(tokens)
51+
52+
53+
def build_args_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for the exported-program runner.

    Starts from the parser returned by the Llama export library's
    ``build_args_parser`` (imported here as ``_build_args_parser``) and adds
    the runner-specific options ``--prompt``, ``--pt2`` and ``--temperature``.

    Returns:
        The extended ``argparse.ArgumentParser``.
    """
    parser = _build_args_parser()

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
        help="Prompt text passed to the runner's text_completion.",
    )

    parser.add_argument(
        "--pt2",
        type=str,
        required=True,
        help="Path to the exported program (.pt2) loaded with torch.export.load.",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        # argparse does not run `type` on non-string defaults, so use a float
        # literal to keep the attribute's type the same whether or not the
        # flag is given.
        default=0.0,
        help="Temperature value passed to the runner's text_completion.",
    )

    return parser
75+
76+
77+
def main() -> None:
    """
    Command-line entry point.

    Parses the arguments, builds an ``ExportedLlamaRunner`` from them, runs a
    single text completion for the given prompt and temperature, and prints
    the ``generation`` and ``tokens`` entries of the result.
    """
    cli_args = build_args_parser().parse_args()

    completion = ExportedLlamaRunner(cli_args).text_completion(
        prompt=cli_args.prompt,
        temperature=cli_args.temperature,
    )

    report_template = "Response: \n{response}\n Tokens:\n {tokens}"
    report = report_template.format(
        response=completion["generation"],
        tokens=completion["tokens"],
    )
    print(report)
91+
92+
93+
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()  # pragma: no cover

0 commit comments

Comments
 (0)