
Commit 1e26f60

Add executorch runner
1 parent f504cc5 · commit 1e26f60

File tree

1 file changed (+131, -0):

  • examples/models/llama3_2_vision/runner

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import (
    EXECUTORCH_DEFINED_MODELS,
    TORCHTUNE_DEFINED_MODELS,
)
from executorch.examples.models.llama3_2_vision.runner.generation import (
    TorchTuneLlamaRunner,
)

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

# Note: import this after portable_lib
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
from executorch.kernels import quantized  # noqa


class NativeLlamaRunner(TorchTuneLlamaRunner):
    """
    Runs llama via ExecuTorch with provided pte file.
    """

    def __init__(self, args):
        with open(args.params, "r") as f:
            params = json.loads(f.read())
        super().__init__(
            tokenizer_path=args.tokenizer,
            max_seq_len=args.max_len,
            max_batch_size=1,
            use_kv_cache=args.kv_cache,
            vocab_size=params["vocab_size"],
        )
        self.model = _load_for_executorch(args.pte)
        self.use_kv_cache = args.kv_cache

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        return (
            self.model.forward((tokens, input_pos, mask))
            if self.use_kv_cache
            else self.model.forward((tokens,))
        )[0]


def build_args_parser() -> argparse.ArgumentParser:
    # TODO: merge these with build_args_parser from export_llama_lib.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model",
        default="llama3",
        choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS,
    )

    parser.add_argument(
        "-f",
        "--pte",
        type=str,
        default=None,
        help="path to exported executorch .pte file",
    )

    parser.add_argument(
        "-p", "--params", type=str, default=None, help="model params file"
    )

    parser.add_argument(
        "-t",
        "--tokenizer",
        type=str,
        default=None,
    )

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
    )

    parser.add_argument(
        "-kv",
        "--kv_cache",
        action="store_true",
    )

    parser.add_argument(
        "--max_len",
        type=int,
        default=128,
        help="Maximum length of the generated response sequence.",
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()
    runner = NativeLlamaRunner(args)
    generated_tokens = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(f"Response: {generated_tokens}")


if __name__ == "__main__":
    main()  # pragma: no cover
```
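For orientation, below is a minimal sketch of driving a loaded program directly through the pybindings, following the same tuple-in/tuple-out convention as the `forward` method above. The `.pte` path and token ids are placeholders, and the sketch assumes the program was exported without a kv cache, matching the else-branch of `NativeLlamaRunner.forward`:

```python
# A minimal sketch, not part of the commit. "model.pte" is a placeholder path
# for a program exported WITHOUT a kv cache, so it takes a single-element
# tuple of inputs, as in the else-branch of NativeLlamaRunner.forward.
import torch

from executorch.extension.pybindings.portable_lib import _load_for_executorch

module = _load_for_executorch("model.pte")  # placeholder path

tokens = torch.tensor([[1, 15043, 3186]], dtype=torch.long)  # placeholder ids
logits = module.forward((tokens,))[0]  # outputs come back as a tuple; [0] is logits

next_token = torch.argmax(logits[:, -1, :], dim=-1)  # greedy pick of the next token
```

The `[0]` index mirrors the runner's own `forward`: the loaded module returns a tuple of outputs, and the logits tensor is its first element.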

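And a hedged end-to-end example of using the runner programmatically, mirroring `main()`. Every path below is a placeholder for artifacts produced by the export flow, and the module name `native` is an assumption, since the diff only shows the directory `examples/models/llama3_2_vision/runner`:

```python
# A usage sketch, not part of the commit. The import path "runner.native" is
# an assumed module name; all file paths are placeholders.
from executorch.examples.models.llama3_2_vision.runner.native import (
    NativeLlamaRunner,
    build_args_parser,
)

args = build_args_parser().parse_args(
    [
        "--pte", "llama3_2_vision.pte",    # placeholder .pte export artifact
        "--params", "params.json",         # placeholder params file with vocab_size
        "--tokenizer", "tokenizer.model",  # placeholder tokenizer path
        "--prompt", "What is the capital of France?",
        "-kv",                             # exercise the kv-cache branch of forward()
    ]
)

runner = NativeLlamaRunner(args)
response = runner.text_completion(prompt=args.prompt, temperature=args.temperature)
print(f"Response: {response}")
```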