Commit eeeeb8a

Add runner
1 parent aa289ea commit eeeeb8a

2 files changed: +186 lines, -0 lines
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
    TORCHTUNE_DEFINED_MODELS,
)
from executorch.examples.models.llama3_2_vision.runner.generation import (
    TorchTuneLlamaRunner,
)
from executorch.extension.llm.export import LLMEdgeManager


class EagerLlamaRunner(TorchTuneLlamaRunner):
    """
    Runs llama in eager mode with the provided checkpoint file.
    """

    def __init__(self, args):
        with open(args.params, "r") as f:
            params = json.loads(f.read())
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=params["vocab_size"],
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        # Build the eager (non-exported) model and move it to the runner's device.
        manager: LLMEdgeManager = _prepare_for_llama_export(args)
        self.model = manager.model.eval().to(device=self.device)

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        input_pos: Optional[torch.LongTensor] = None,
        mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        return self.model.forward(tokens=tokens, input_pos=input_pos, mask=mask)


def build_args_parser() -> argparse.ArgumentParser:
    parser = _build_args_parser()

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0,
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    runner = EagerLlamaRunner(args)
    result = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(
        "Response: \n{response}\n Tokens:\n {tokens}".format(
            response=result["generation"], tokens=result["tokens"]
        )
    )


if __name__ == "__main__":
    main()  # pragma: no cover
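
For reference, a minimal sketch of how this eager runner might be exercised end to end, assuming it runs alongside the module above. The `--prompt` and `--temperature` flags are defined in this file; the remaining flag names and the file paths are assumptions about what the shared `_build_args_parser()` from `export_llama_lib` exposes, so treat them as illustrative only.

# Hypothetical driver for EagerLlamaRunner; flag names other than --prompt and
# --temperature are assumed, not confirmed by this diff.
args = build_args_parser().parse_args(
    [
        "--checkpoint", "/path/to/consolidated.00.pth",  # assumed flag and path
        "--params", "/path/to/params.json",              # assumed flag and path
        "--tokenizer_path", "/path/to/tokenizer.model",  # assumed flag and path
        "--max_seq_length", "512",                       # assumed flag
        "--prompt", "What is the capital of France?",
        "--temperature", "0.6",
    ]
)
runner = EagerLlamaRunner(args)
result = runner.text_completion(prompt=args.prompt, temperature=args.temperature)
print(result["generation"])
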
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import List, Optional, TypedDict

import torch

from executorch.extension.llm.tokenizer.utils import get_tokenizer
from executorch.examples.models.llama.runner.generation import (
    LlamaRunner,
    next_token,
    sample_top_p,
)


class TorchTuneLlamaRunner(LlamaRunner):
    def __init__(
        self,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        use_kv_cache: bool,
        vocab_size: int,
        device: str = "cpu",
    ):
        super().__init__(
            tokenizer_path,
            max_seq_len,
            max_batch_size,
            use_kv_cache,
            vocab_size,
            device,
        )

        # Full-length causal mask and position ids; sliced per step during generation.
        self.causal_mask = torch.tril(
            torch.ones(
                size=(max_seq_len, max_seq_len),
                dtype=torch.bool,
            )
        )
        self.input_pos = torch.arange(max_seq_len)

    def generate(  # noqa: C901
        self,
        prompt_tokens: List[int],
        max_seq_len: int,
        temperature: float = 0.8,
        top_p: float = 0.9,
        echo: bool = False,
    ) -> List[int]:
        # Prefill: run the whole prompt through the model in one pass.
        seq_len = len(prompt_tokens)
        input_pos = self.input_pos[None, :seq_len]
        mask = self.causal_mask[None, :seq_len]
        if self.use_kv_cache:
            logits = self.forward(
                tokens=torch.tensor(
                    [prompt_tokens], dtype=torch.long, device=self.device
                ),
                input_pos=input_pos,
                mask=mask,
            )
        else:
            logits = self.forward(
                tokens=torch.tensor(
                    [prompt_tokens], dtype=torch.long, device=self.device
                ),
            )

        # Only need the last logit.
        current_token = next_token(logits[:, -1, :], temperature, top_p)
        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
        tokens = prompt_tokens + [current_token]

        # Decode: generate one token at a time until max_seq_len or a stop token.
        while len(tokens) < max_seq_len:
            mask = self.causal_mask[None, seq_len, None, :]
            input_pos = self.input_pos[None, seq_len, None]
            if self.use_kv_cache:
                logits = self.forward(
                    tokens=torch.tensor(
                        [[current_token]], dtype=torch.long, device=self.device
                    ),
                    input_pos=input_pos,
                    mask=mask,
                )
            else:
                logits = self.forward(
                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
                )

            # Only need the last logit.
            current_token = next_token(logits[:, -1, :], temperature, top_p)
            tokens.append(current_token)

            if current_token == self.tokenizer.eos_id or (
                hasattr(self.tokenizer, "stop_tokens")
                and current_token in self.tokenizer.stop_tokens
            ):
                break

            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
            seq_len += 1

        return tokens if echo else tokens[len(prompt_tokens) :]
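
A note on sampling: the decode loop above delegates token selection to next_token, imported from the llama runner's generation module, which is not part of this diff. The sketch below is only a generic illustration of temperature plus top-p (nucleus) sampling with the same call shape; it is not the imported implementation.

import torch

def sample_next_token_sketch(
    logits: torch.Tensor, temperature: float, top_p: float
) -> int:
    """Illustrative temperature + nucleus (top-p) sampling over a [1, vocab] logit row."""
    if temperature <= 0:
        # Greedy decoding when temperature is 0 (the eager runner's default).
        return int(torch.argmax(logits, dim=-1).item())
    probs = torch.softmax(logits / temperature, dim=-1)
    # Sort probabilities and keep the smallest prefix whose cumulative mass covers top_p.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    # Sample within the truncated distribution and map back to vocab ids.
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return int(torch.gather(sorted_idx, -1, choice).item())
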
