Skip to content

Commit d595735

Browse files
jackzhxngfacebook-github-bot
authored andcommitted
Add exported program runner (#6969)
Summary: Add an `ExportedProgram` runner for TorchTune Llama. Test Plan: ``` # Download resources tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct # Export model python -m examples.models.llama.export_llama --model llama3_2_vision --checkpoint /tmp/Llama-3.2-11B-Vision-Instruct/original/consolidated.pth --params examples/models/llama3_2_vision/text_decoder/params/demo_config.json --metadata '{"append_eos_to_prompt": 0, "get_bos_id":128000, "get_eos_ids":[128009, 128001], "get_n_bos": 0, "get_n_eos": 0}' --output_name="llama3_2_vision.pt2" -d fp32 --verbose --max_seq_length 64 --export_only -kv # Run ExportedProgram python -m examples.models.llama3_2_vision.runner.exported --model llama3_2_vision --pt2 llama3_2_vision.pt2 --tokenizer /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model --prompt "How many calories are in bread?" --params examples/models/llama3_2_vision/text_decoder/params/demo_config.json --max_seq_length 64 -kv ``` Output: ``` The number of calories in bread can vary greatly depending on the type of bread, its ingredients, and its size. Here are the approximate calorie counts for different types of bread: White bread: 80-100 calories per slice ``` Differential Revision: D66186052 Pulled By: dvorjackz
1 parent 711f1c2 commit d595735

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import argparse
8+
import json
9+
from typing import Optional
10+
11+
import torch
12+
13+
from executorch.examples.models.llama.export_llama_lib import (
14+
build_args_parser as _build_args_parser,
15+
)
16+
from executorch.examples.models.llama3_2_vision.runner.generation import (
17+
TorchTuneLlamaRunner,
18+
)
19+
20+
21+
class ExportedLlamaRunner(TorchTuneLlamaRunner):
22+
"""
23+
Runs a torch-exported .pt2 Llama.
24+
"""
25+
26+
def __init__(self, args):
27+
with open(args.params, "r") as f:
28+
params = json.loads(f.read())
29+
super().__init__(
30+
tokenizer_path=args.tokenizer_path,
31+
max_seq_len=args.max_seq_length,
32+
max_batch_size=1,
33+
use_kv_cache=args.use_kv_cache,
34+
vocab_size=params["vocab_size"],
35+
device="cuda" if torch.cuda.is_available() else "cpu",
36+
)
37+
print(f"Loading model from {args.pt2}")
38+
self.model = torch.export.load(args.pt2).module()
39+
print("Model loaded")
40+
41+
def forward(
42+
self,
43+
tokens: Optional[torch.LongTensor] = None,
44+
input_pos: Optional[torch.LongTensor] = None,
45+
mask: Optional[torch.LongTensor] = None,
46+
) -> torch.Tensor:
47+
if self.use_kv_cache:
48+
return self.model(tokens, input_pos=input_pos, mask=mask)
49+
else:
50+
return self.model(tokens)
51+
52+
53+
def build_args_parser() -> argparse.ArgumentParser:
54+
parser = _build_args_parser()
55+
56+
parser.add_argument(
57+
"--prompt",
58+
type=str,
59+
default="Hello",
60+
)
61+
62+
parser.add_argument(
63+
"--pt2",
64+
type=str,
65+
required=True,
66+
)
67+
68+
parser.add_argument(
69+
"--temperature",
70+
type=float,
71+
default=0,
72+
)
73+
74+
return parser
75+
76+
77+
def main() -> None:
78+
parser = build_args_parser()
79+
args = parser.parse_args()
80+
81+
runner = ExportedLlamaRunner(args)
82+
result = runner.text_completion(
83+
prompt=args.prompt,
84+
temperature=args.temperature,
85+
)
86+
print(
87+
"Response: \n{response}\n Tokens:\n {tokens}".format(
88+
response=result["generation"], tokens=result["tokens"]
89+
)
90+
)
91+
92+
93+
if __name__ == "__main__":
94+
main() # pragma: no cover

0 commit comments

Comments
 (0)