Commit 595eb1c

Make inputs actually contiguously laid out in memory
1 parent 91afa6e commit 595eb1c

File tree

1 file changed (+13, -2 lines)
  • examples/models/llama3_2_vision/text_decoder

examples/models/llama3_2_vision/text_decoder/model.py

Lines changed: 13 additions & 2 deletions
@@ -167,11 +167,22 @@ def get_example_inputs(self):
     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
+
+        # Make input_pos and mask contiguous in memory.
+        input_pos = self.input_pos[None, : self.n_tokens]
+        mask = self.causal_mask[None, : self.n_tokens]
+        contiguous_input_pos = torch.empty_like(
+            input_pos, memory_format=torch.contiguous_format
+        )
+        contiguous_input_pos.data.copy_(input_pos.data)
+        contiguous_mask = torch.empty_like(mask, memory_format=torch.contiguous_format)
+        contiguous_mask.data.copy_(mask.data)
+
         # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             return {
-                "input_pos": self.input_pos[None, : self.n_tokens],
-                "mask": self.causal_mask[None, : self.n_tokens],
+                "input_pos": contiguous_input_pos,
+                "mask": contiguous_mask,
                 "encoder_input": torch.randn(
                     1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
                 ),
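
For context: the change replaces sliced views of the persistent input_pos and
causal_mask buffers with freshly allocated copies before handing them to export
as example inputs. Below is a minimal standalone sketch of the same pattern;
the sizes and names (max_seq_len, n_tokens, pos_view, mask_view) are
illustrative assumptions, not the model's real configuration.

    import torch

    # Hypothetical sizes, chosen only to keep the sketch self-contained.
    max_seq_len, n_tokens = 8, 3

    input_pos = torch.arange(max_seq_len)
    causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

    # Slicing produces views: they alias the full buffers' storage instead of
    # owning an exactly-sized allocation of their own.
    pos_view = input_pos[None, : n_tokens]
    mask_view = causal_mask[None, : n_tokens]
    print(mask_view.data_ptr() == causal_mask.data_ptr())  # True: storage is shared

    # The pattern from the diff: allocate fresh contiguous storage of exactly
    # the sliced shape, then copy the values in.
    contiguous_pos = torch.empty_like(pos_view, memory_format=torch.contiguous_format)
    contiguous_pos.copy_(pos_view)
    contiguous_mask = torch.empty_like(mask_view, memory_format=torch.contiguous_format)
    contiguous_mask.copy_(mask_view)

    print(contiguous_mask.data_ptr() == causal_mask.data_ptr())  # False: owns its memory
    print(contiguous_mask.is_contiguous())                       # True

Note that Tensor.contiguous() returns the tensor unchanged when the view is
already contiguous, so it would not guarantee fresh storage here; the explicit
empty_like/copy_ pair always copies.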
