Skip to content

Commit 2a292c3

Browse files
authored
Make inputs actually contiguously laid out in memory (#7072)
1 parent 97a8a89 commit 2a292c3

File tree

1 file changed

+13
-2
lines changed
  • examples/models/llama3_2_vision/text_decoder

1 file changed

+13
-2
lines changed

examples/models/llama3_2_vision/text_decoder/model.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -168,11 +168,22 @@ def get_example_inputs(self):
168 168
def get_example_kwarg_inputs(self):
169169
# For export we must use the prefill versions of the
170 170
# causal mask and input_pos.
171+
172+
# Make input_pos and mask contiguous in memory.
173+
input_pos = self.input_pos[None, : self.n_tokens]
174+
mask = self.causal_mask[None, : self.n_tokens]
175+
contiguous_input_pos = torch.empty_like(
176+
input_pos, memory_format=torch.contiguous_format
177+
)
178+
contiguous_input_pos.data.copy_(input_pos.data)
179+
contiguous_mask = torch.empty_like(mask, memory_format=torch.contiguous_format)
180+
contiguous_mask.data.copy_(mask.data)
181+
171 182
# Hardcoding # of tiles to be 2. image tokens per tile is 1601.
172 183
if self.use_kv_cache:
173 184
return {
174-
"input_pos": self.input_pos[None, : self.n_tokens],
175-
"mask": self.causal_mask[None, : self.n_tokens],
185+
"input_pos": contiguous_input_pos,
186+
"mask": contiguous_mask,
176 187
"encoder_input": torch.randn(
177 188
1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
178 189
),

0 commit comments

Comments (0)