Commit 7d52002

Export model with KV cache + runner for Torchtune models
1 parent 37011d3

File tree

examples/models/llama/runner/generation.py
examples/models/llama3_2_vision/model.py
examples/models/model_factory.py

3 files changed: +17 -39 lines changed

examples/models/llama/runner/generation.py

Lines changed: 3 additions & 26 deletions
@@ -57,26 +57,12 @@ def __init__(
         max_batch_size: int,
         use_kv_cache: bool,
         vocab_size: int,
-        has_full_logits: bool = False,
         device: str = "cpu",
     ):
-        """
-        Constructor.
-
-        Args:
-            tokenizer_path: path to tokenizer.model file.
-            max_seq_len: max length of the output sequence, after which the output will be clipped.
-            max_batch_size: max batch size.
-            use_kv_cache: whether to use a KV cache.
-            vocab_size: number of items in the vocab.
-            has_full_logits: whether the model returns the full logits or only returns the last logit.
-            device: device to run the runner on.
-        """
         self.max_seq_len = max_seq_len
         self.max_batch_size = max_batch_size
         self.use_kv_cache = use_kv_cache
         self.tokenizer = get_tokenizer(tokenizer_path)
-        self.has_full_logits = has_full_logits
         self.device = device
         assert vocab_size == self.tokenizer.n_words

@@ -95,7 +81,7 @@ def generate( # noqa: C901
         top_p: float = 0.9,
         echo: bool = False,
     ) -> List[int]:
-        # prefill
+        # Prefill
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
@@ -105,11 +91,7 @@ def generate( # noqa: C901
             ),
         )

-        current_token = next_token(logits[:, -1, :], temperature, top_p)
-        if self.has_full_logits:
-            current_token = next_token(logits[:, -1, :], temperature, top_p)
-        else:
-            current_token = next_token(logits, temperature, top_p)
+        current_token = next_token(logits, temperature, top_p)
         tokens = prompt_tokens + [current_token]

         i = 0
@@ -129,12 +111,7 @@ def generate( # noqa: C901
                 tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
             )

-            # If the logits aren't already clipped to only contain the last logit, clip them.
-            if self.has_full_logits:
-                current_token = next_token(logits[:, -1, :], temperature, top_p)
-            else:
-                current_token = next_token(logits, temperature, top_p)
-
+            current_token = next_token(logits, temperature, top_p)
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self.tokenizer, "stop_tokens")
                 and current_token in self.tokenizer.stop_tokens
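With has_full_logits gone, the runner assumes every exported model returns only the last position's logits, so both the prefill and decode paths call next_token(logits, ...) directly. For reference, here is a minimal sketch of what a next_token helper like the one called above typically does (temperature scaling plus top-p sampling); the name and details are assumptions for illustration, not the repo's exact implementation:

import torch


def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
    # Assumes `logits` holds only the last position's scores, e.g. shape
    # (1, vocab_size), which is what the runner expects after this change.
    if temperature <= 0:
        return int(torch.argmax(logits, dim=-1).item())
    probs = torch.softmax(logits / temperature, dim=-1).squeeze(0)
    # Top-p (nucleus) sampling: keep the smallest high-probability set whose
    # cumulative mass reaches top_p, renormalize, then sample from it.
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum()
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return int(sorted_idx[choice].item())

A model that still produced full (batch, seq_len, vocab) logits would need the logits[:, -1, :] slice that the deleted branch used to apply.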

examples/models/llama3_2_vision/model.py

Lines changed: 13 additions & 12 deletions
@@ -134,12 +134,12 @@ def __init__(self, **kwargs):

         self.model_ = prune_output_vocab(self.model_, output_prune_map)

-        # if self.use_kv_cache:
-        #     print("Setting up KV cache on the model...")
-        #     self.model_.setup_caches(
-        #         batch_size=1,
-        #         dtype=self.dtype,
-        #     )
+        if self.use_kv_cache:
+            print("Setting up KV cache on the model...")
+            self.model_.setup_caches(
+                batch_size=1,
+                dtype=self.dtype,
+            )

     def get_eager_model(self) -> torch.nn.Module:
         if self.dtype:
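Re-enabling this block makes the torchtune decoder allocate its KV caches when the model is constructed, so the exported graph contains in-place cache reads and writes instead of recomputing attention over the whole history. As an illustration of what setup_caches allocates per attention layer, here is a stand-in cache in the spirit of torchtune's KVCache (the class and its update signature are assumptions, not torchtune's exact API):

import torch


class SketchKVCache:
    def __init__(self, batch_size: int, max_seq_len: int, n_heads: int,
                 head_dim: int, dtype: torch.dtype):
        # Fixed-size key/value buffers, sized once up front.
        shape = (batch_size, n_heads, max_seq_len, head_dim)
        self.k = torch.zeros(shape, dtype=dtype)
        self.v = torch.zeros(shape, dtype=dtype)

    def update(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # Write the new keys/values at their absolute positions, then return
        # the full buffers so attention can cover everything cached so far.
        self.k[:, :, input_pos] = k
        self.v[:, :, input_pos] = v
        return self.k, self.v

With batch_size=1, each forward pass of the exported model then only needs the new tokens plus their positions, which is what the example inputs below provide.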
@@ -148,15 +148,16 @@ def get_eager_model(self) -> torch.nn.Module:
         return self.model_.to(torch.float16)

     def get_example_inputs(self):
-        return (torch.ones(1, 64, dtype=torch.long),)
+        return (torch.ones(1, 32, dtype=torch.long),)

     def get_example_kwarg_inputs(self):
-        # TODO: add input_pos and mask when after making cache work.
+        # For export we must use the prefill versions of the
+        # causal mask and input_pos.
         return {
-            # "mask": self.causal_mask[None, 64, None, :],
+            "mask": self.causal_mask[None, :32],
             # "encoder_input": None,
             # "encoder_mask": None,
-            # "input_pos": self.input_pos[None, 64]
+            "input_pos": self.input_pos[None, :32]
         }

     def get_dynamic_shapes(self):
@@ -166,7 +167,7 @@ def get_dynamic_shapes(self):
             "tokens": {0: batch_size, 1: dim_seq_len},
             # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
             # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
-            # "mask": {0: batch_size, 1: dim_seq_len, 2: self.max_seq_len},
-            # "input_pos" : {0: batch_size, 1: dim_seq_len},
+            "mask": {0: batch_size, 1: dim_seq_len, 2: dim_seq_len},
+            "input_pos" : {0: batch_size, 1: dim_seq_len},
         }
         return dynamic_shapes
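Because mask and input_pos are now real keyword inputs rather than comments, their sequence dimensions must be registered as dynamic alongside tokens, which is what the uncommented entries above do. Below is a self-contained sketch of the same pattern with torch.export; TinyDecoder, the Dim bounds, and all sizes are stand-ins (the real batch_size and dim_seq_len objects are defined earlier in this file), and the example mask is square here so both of its sequence dims can share one Dim:

import torch
from torch.export import Dim, export


class TinyDecoder(torch.nn.Module):
    # Stand-in for the torchtune decoder; only the call signature matters here.
    def __init__(self, vocab: int = 128, dim: int = 16, max_seq_len: int = 128):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, dim)
        self.pos = torch.nn.Embedding(max_seq_len, dim)
        self.out = torch.nn.Linear(dim, vocab)

    def forward(self, tokens, *, mask=None, input_pos=None):
        h = self.emb(tokens) + self.pos(input_pos)           # (1, s, dim)
        h = h * mask.to(h.dtype).mean(dim=-1, keepdim=True)  # exercise the mask input
        return self.out(h)                                   # (1, s, vocab)


max_seq_len = 128
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.arange(max_seq_len)

seq_len = 32  # prefill length used for the example inputs
example_args = (torch.ones(1, seq_len, dtype=torch.long),)
example_kwargs = {
    "mask": causal_mask[None, :seq_len, :seq_len],
    "input_pos": input_pos[None, :seq_len],
}
dim_seq_len = Dim("token_dim", min=2, max=max_seq_len)
dynamic_shapes = {
    "tokens": {1: dim_seq_len},
    "mask": {1: dim_seq_len, 2: dim_seq_len},
    "input_pos": {1: dim_seq_len},
}

ep = export(TinyDecoder(), example_args, kwargs=example_kwargs,
            dynamic_shapes=dynamic_shapes)
print(ep)  # the exported program now accepts any prefill length in range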

examples/models/model_factory.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def create_model(
        model = model_class(**kwargs)
        example_kwarg_inputs = None
        dynamic_shapes = None
-       if hasattr(model, "get_example_kwarg_inputs()"):
+       if hasattr(model, "get_example_kwarg_inputs"):
            example_kwarg_inputs = model.get_example_kwarg_inputs()
        if hasattr(model, "get_dynamic_shapes"):
            dynamic_shapes = model.get_dynamic_shapes()
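The old string included parentheses, so hasattr looked for an attribute literally named "get_example_kwarg_inputs()", which can never exist; the check always returned False and example kwargs were silently skipped. A quick illustration:

class Model:
    def get_example_kwarg_inputs(self):
        return {"mask": None}


m = Model()
print(hasattr(m, "get_example_kwarg_inputs()"))  # False: no attribute has that literal name
print(hasattr(m, "get_example_kwarg_inputs"))    # True: the method exists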
