
Commit b7c8315

Add Llama3.2 1B as a new example model
1 parent e44b259 commit b7c8315

File tree

7 files changed, +447 −1 lines changed


examples/models/llama/export_llama_lib.py

Lines changed: 4 additions & 1 deletion
@@ -78,7 +78,7 @@


 EXECUTORCH_DEFINED_MODELS = ["llama2", "llama3", "llama3_1", "llama3_2"]
-TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"]
+TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision", "llama3_2_tt"]


 class WeightType(Enum):

@@ -811,6 +811,9 @@ def _load_llama_model(
     elif modelname in TORCHTUNE_DEFINED_MODELS:
         if modelname == "llama3_2_vision":
             model_class_name = "Llama3_2Decoder"
+        if modelname == "llama3_2_tt":
+            modelname = "llama3_2"
+            model_class_name = "Llama3_2"
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
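The new entry registers "llama3_2_tt" as a user-facing model name while reusing the existing llama3_2 example directory: when that name is selected, _load_llama_model rewrites modelname to "llama3_2" (the directory of the new files added below) and sets model_class_name to the new Llama3_2 class. A minimal sketch of that remapping, using a hypothetical resolve_model helper that is not part of the commit:

# Hypothetical sketch of the name -> (directory, class) remapping that the
# _load_llama_model change performs for TorchTune-defined models. Not part of the commit.
def resolve_model(modelname: str) -> tuple[str, str]:
    if modelname == "llama3_2_vision":
        return "llama3_2_vision", "Llama3_2Decoder"
    if modelname == "llama3_2_tt":
        # The user-facing name maps onto the examples/models/llama3_2 directory.
        return "llama3_2", "Llama3_2"
    raise ValueError(f"{modelname} is not a valid Llama model.")

assert resolve_model("llama3_2_tt") == ("llama3_2", "Llama3_2")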

examples/models/llama/runner/generation.py

Lines changed: 3 additions & 0 deletions
@@ -146,12 +146,15 @@ def text_completion(
         This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
         """
         prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
+        print(f"Encoded prompt: {prompt_tokens}")
+        print("Generating")
         generation_tokens = self.generate(
             prompt_tokens=prompt_tokens,
             temperature=temperature,
             top_p=top_p,
             echo=echo,
         )
+        print("Generated")
         return {
             "generation": self.tokenizer.decode(generation_tokens),
             "tokens": generation_tokens,
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import Llama3_2

__all__ = ["Llama3_2"]

examples/models/llama3_2/model.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import json
from typing import Any, Dict

import torch
from executorch.examples.models.checkpoint import (
    get_checkpoint_dtype,
    get_default_model_resource_dir,
)

from executorch.examples.models.model_base import EagerModelBase
from torchtune.models.llama3_2._model_builders import llama3_2_1b
from torchtune.models.convert_weights import meta_to_tune


class Llama3_2(EagerModelBase):
    """
    Llama3.2 model as defined in TorchTune.
    """

    def __init__(self, **kwargs):
        # Set member vars from kwargs.
        self.max_seq_len = kwargs.get(
            "max_seq_len", 8192
        )  # Trained to be a lot larger, but this value is kept small because of static kv cache at the moment.
        self.encoder_max_seq_len = kwargs.get(
            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
        )  # Same as above.
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.verbose = kwargs.get("verbose", False)
        self.args = kwargs.get("args", None)

        ckpt_dir = get_default_model_resource_dir(__file__)
        # Single checkpoint file.
        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
        # Sharded checkpoint.
        checkpoint_dir = kwargs.get("checkpoint_dir", None)
        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")

        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )
        self.input_pos = torch.arange(self.max_seq_len)

        # Load checkpoint and params.
        device = "cpu"
        if checkpoint_dir is not None:
            raise NotImplementedError(
                "Sharded checkpoint not yet supported for Llama3_2."
            )
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
        checkpoint = meta_to_tune(checkpoint)
        with open(params_path, "r") as f:
            params = json.loads(f.read())

        # Find dtype from checkpoint. (skip for now)
        self.dtype = get_checkpoint_dtype(checkpoint)

        # Load model.
        self.model_ = llama3_2_1b()

        # Save params for future use.
        for param_name, param_val in params.items():
            setattr(self.model_, param_name, param_val)

        # Quantize. (skip for now)

        # Load checkpoint.
        missing, unexpected = self.model_.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )
        if kwargs.get("verbose", False):
            print("============= missing keys ================")
            print(missing)
            print("============= /missing ================")
            print("============= unexpected keys ================")
            print(unexpected)
            print("============= /unexpected ================")

        # Prune the output layer if output_prune_map is provided.
        output_prune_map = None
        if self.output_prune_map_path is not None:
            from executorch.examples.models.llama2.source_transformation.prune_output import (
                prune_output_vocab,
            )

            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
            # Change keys from string to int (json only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

        if self.use_kv_cache:
            print("Setting up KV cache on the model...")
            self.model_.setup_caches(
                batch_size=1,
                dtype=self.dtype,
                decoder_max_seq_len=self.max_seq_len,
            )

    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            return self.model_.to(self.dtype)
        else:
            return self.model_.to(torch.float16)

    def get_example_inputs(self):
        return (torch.ones(1, 32, dtype=torch.long),)

    def get_example_kwarg_inputs(self):
        # For export we must use the prefill versions of the
        # causal mask and input_pos.
        if self.use_kv_cache:
            return {
                "input_pos": self.input_pos[None, :32],
                "mask": self.causal_mask[None, :32],
            }
        else:
            return None

    def get_dynamic_shapes(self):
        batch_size = 1
        dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
        if self.use_kv_cache:
            dynamic_shapes = {
                "tokens": {0: batch_size, 1: dim_seq_len},
                "input_pos": {0: batch_size, 1: dim_seq_len},
                "mask": {0: batch_size, 1: dim_seq_len, 2: None},
            }
        else:
            dynamic_shapes = {
                "tokens": {0: batch_size, 1: dim_seq_len},
            }
        return dynamic_shapes
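The class implements ExecuTorch's EagerModelBase interface, so the example inputs, kwarg inputs, and dynamic shapes it returns are shaped for torch.export. Below is a minimal sketch of how those pieces might fit together; the checkpoint and params paths are placeholders, and the output file name is only an assumption matching the .pt2 runner added later in this commit.

import torch
from executorch.examples.models.llama3_2 import Llama3_2

# Placeholder paths: point them at a real Llama 3.2 1B checkpoint (Meta format)
# and its params JSON.
model_wrapper = Llama3_2(
    checkpoint="/path/to/consolidated.00.pth",
    params="/path/to/params.json",
    use_kv_cache=True,
    max_seq_len=8192,
)

eager_model = model_wrapper.get_eager_model()
example_args = model_wrapper.get_example_inputs()          # (tokens,) of shape [1, 32]
example_kwargs = model_wrapper.get_example_kwarg_inputs()  # prefill input_pos and mask slices
dynamic_shapes = model_wrapper.get_dynamic_shapes()        # token dim dynamic up to max_seq_len

# Export with the prefill example inputs and save as a .pt2 artifact.
exported = torch.export.export(
    eager_model,
    example_args,
    kwargs=example_kwargs,
    dynamic_shapes=dynamic_shapes,
)
torch.export.save(exported, "llama3_2_1b.pt2")  # assumed name, consumable via --pt2 below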
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
    TORCHTUNE_DEFINED_MODELS,
)
from executorch.examples.models.llama3_2_vision.runner.generation import TorchTuneLlamaRunner
from executorch.extension.llm.export import LLMEdgeManager


class EagerLlamaRunner(TorchTuneLlamaRunner):
    """
    Runs llama in eager mode with provided checkpoint file.
    """

    def __init__(self, args):
        with open(args.params, "r") as f:
            params = json.loads(f.read())
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=params["vocab_size"],
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        manager: LLMEdgeManager = _prepare_for_llama_export(args)
        self.model = manager.model.eval().to(device=self.device)

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        input_pos: Optional[torch.LongTensor] = None,
        mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        return self.model.forward(tokens=tokens, input_pos=input_pos, mask=mask)


def build_args_parser() -> argparse.ArgumentParser:
    parser = _build_args_parser()

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0,
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    runner = EagerLlamaRunner(args)
    result = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(
        "Response: \n{response}\n Tokens:\n {tokens}".format(
            response=result["generation"], tokens=result["tokens"]
        )
    )


if __name__ == "__main__":
    main()  # pragma: no cover
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
    TORCHTUNE_DEFINED_MODELS,
)
from executorch.examples.models.llama3_2_vision.runner.generation import TorchTuneLlamaRunner
from executorch.extension.llm.export import LLMEdgeManager


class ExportedLlamaRunner(TorchTuneLlamaRunner):
    """
    Runs a torch-exported .pt2 Llama.
    """

    def __init__(self, args):
        with open(args.params, "r") as f:
            params = json.loads(f.read())
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=params["vocab_size"],
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
        print(f"Loading model from {args.pt2}")
        self.model = torch.export.load(args.pt2).module()
        print("Model loaded")

    def forward(
        self,
        tokens: Optional[torch.LongTensor] = None,
        input_pos: Optional[torch.LongTensor] = None,
        mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        print("Forward")
        if self.use_kv_cache:
            return self.model(tokens, input_pos=input_pos, mask=mask)
        else:
            return self.model(tokens)


def build_args_parser() -> argparse.ArgumentParser:
    parser = _build_args_parser()

    parser.add_argument(
        "--prompt",
        type=str,
        default="Hello",
    )

    parser.add_argument(
        "--pt2",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=0,
    )

    return parser


def main() -> None:
    parser = build_args_parser()
    args = parser.parse_args()

    runner = ExportedLlamaRunner(args)
    result = runner.text_completion(
        prompt=args.prompt,
        temperature=args.temperature,
    )
    print(
        "Response: \n{response}\n Tokens:\n {tokens}".format(
            response=result["generation"], tokens=result["tokens"]
        )
    )


if __name__ == "__main__":
    main()  # pragma: no cover
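ExportedLlamaRunner.forward feeds tokens (plus input_pos and mask when a KV cache was exported) straight into the loaded torch.export program. A minimal sketch of that call pattern outside the runner, assuming a .pt2 produced as in the export sketch above with the default max_seq_len of 8192; the file name is a placeholder:

import torch

# Load the exported program and grab its callable module,
# mirroring what ExportedLlamaRunner does in __init__.
program = torch.export.load("llama3_2_1b.pt2")  # placeholder path
module = program.module()

# A 32-token prefill call, matching the example inputs the model class exports with.
tokens = torch.ones(1, 32, dtype=torch.long)
input_pos = torch.arange(8192)[None, :32]
mask = torch.tril(torch.ones(8192, 8192, dtype=torch.bool))[None, :32]
logits = module(tokens, input_pos=input_pos, mask=mask)
print(logits.shape)  # expected: [1, 32, vocab_size]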
