
Commit e44b259

Export with no kv cache + non-strict load checkpoint
1 parent 7d52002 commit e44b259

1 file changed: +23 -14 lines changed

examples/models/llama3_2_vision/model.py

Lines changed: 23 additions & 14 deletions
```diff
@@ -109,7 +109,7 @@ def __init__(self, **kwargs):
         # Load checkpoint.
         missing, unexpected = self.model_.load_state_dict(
             checkpoint,
-            strict=True,
+            strict=False,
             assign=True,
         )
         if kwargs.get("verbose", False):
```
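With `strict=False`, `load_state_dict` no longer raises when the checkpoint and the module disagree on keys (for example, buffers that exist only in one of the two); it instead returns the `missing` and `unexpected` key lists captured above. A minimal sketch of how those lists could be surfaced, using a hypothetical logging helper rather than anything in this file:

```python
import logging

def log_state_dict_mismatches(missing, unexpected):
    # Hypothetical helper: report, rather than fail on, keys that did not
    # line up between the checkpoint and the model under strict=False.
    if missing:
        logging.warning("Keys missing from checkpoint: %s", missing)
    if unexpected:
        logging.warning("Unexpected keys in checkpoint: %s", unexpected)
```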
```diff
@@ -139,6 +139,7 @@ def __init__(self, **kwargs):
         self.model_.setup_caches(
             batch_size=1,
             dtype=self.dtype,
+            decoder_max_seq_len=self.max_seq_len,
         )
 
     def get_eager_model(self) -> torch.nn.Module:
```
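Passing `decoder_max_seq_len` pins the decoder KV cache to the model's maximum sequence length, which is the same bound the prefill-style export buffers below are sliced from. A rough sketch of that relationship, with `max_seq_len` hard-coded as an assumption (the real value comes from the model config):

```python
import torch

max_seq_len = 8192  # assumption; taken from the model config in practice
# Prefill-style buffers sized to the same bound the decoder cache uses.
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.arange(max_seq_len)
# Slices like these mirror the [None, :32] prefill inputs used for export.
prefill_kwargs = {
    "input_pos": input_pos[None, :32],
    "mask": causal_mask[None, :32],
}
```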
```diff
@@ -153,21 +154,29 @@ def get_example_inputs(self):
     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
-        return {
-            "mask": self.causal_mask[None, :32],
-            # "encoder_input": None,
-            # "encoder_mask": None,
-            "input_pos": self.input_pos[None, :32]
-        }
+        if self.use_kv_cache:
+            return {
+                "input_pos": self.input_pos[None, :32],
+                "mask": self.causal_mask[None, :32],
+                # "encoder_input": None,
+                # "encoder_mask": None,
+            }
+        else:
+            return None
 
     def get_dynamic_shapes(self):
         batch_size = 1
         dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
-        dynamic_shapes = {
-            "tokens": {0: batch_size, 1: dim_seq_len},
-            # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
-            # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
-            "mask": {0: batch_size, 1: dim_seq_len, 2: dim_seq_len},
-            "input_pos" : {0: batch_size, 1: dim_seq_len},
-        }
+        if self.use_kv_cache:
+            dynamic_shapes = {
+                "tokens": {0: batch_size, 1: dim_seq_len},
+                # "encoder_input": {0: 1, 1: dim_enc, 2: 4096},
+                # "encoder_mask": {0: 1, 1: dim, 2: dim_enc},
+                "mask": {0: batch_size, 1: dim_seq_len, 2: None},
+                "input_pos" : {0: batch_size, 1: dim_seq_len},
+            }
+        else:
+            dynamic_shapes = {
+                "tokens": {0: batch_size, 1: dim_seq_len},
+            }
         return dynamic_shapes
```
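Together, these two accessors give `torch.export` a matched pair of example inputs and shape specs: with the KV cache enabled, `mask` and `input_pos` are passed as prefill-style kwargs with a dynamic token dimension; without it, no kwargs are passed and only the `tokens` argument is dynamic. A hedged sketch of the call site, where `export_text_decoder` and the way the model instance is obtained are assumptions, not code from this repository:

```python
import torch

def export_text_decoder(model):
    # `model` is assumed to be an instance of the class patched above.
    example_args = model.get_example_inputs()          # e.g. a (tokens,) tuple
    example_kwargs = model.get_example_kwarg_inputs()  # None when the KV cache is off
    dynamic_shapes = model.get_dynamic_shapes()        # only "tokens" is dynamic without the cache
    return torch.export.export(
        model.get_eager_model(),
        example_args,
        kwargs=example_kwargs,
        dynamic_shapes=dynamic_shapes,
    )
```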
