examples/models/llama3_2_vision/text_decoder (1 file changed, +13, -2 lines)

@@ -168,11 +168,22 @@ def get_example_inputs(self):
     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
+
+        # Make input_pos and mask contiguous in memory.
+        input_pos = self.input_pos[None, : self.n_tokens]
+        mask = self.causal_mask[None, : self.n_tokens]
+        contiguous_input_pos = torch.empty_like(
+            input_pos, memory_format=torch.contiguous_format
+        )
+        contiguous_input_pos.data.copy_(input_pos.data)
+        contiguous_mask = torch.empty_like(mask, memory_format=torch.contiguous_format)
+        contiguous_mask.data.copy_(mask.data)
+
         # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             return {
-                "input_pos": self.input_pos[None, : self.n_tokens],
-                "mask": self.causal_mask[None, : self.n_tokens],
+                "input_pos": contiguous_input_pos,
+                "mask": contiguous_mask,
                 "encoder_input": torch.randn(
                     1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
                 ),
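The hunk replaces slices of the module's `input_pos` and `causal_mask` buffers with freshly allocated copies before returning them as example kwargs, so the exported example inputs are contiguous in memory rather than views of the buffers. Below is a minimal, self-contained sketch of the same `empty_like` + `copy_` pattern, written only for illustration: the variable names and sizes are placeholders, not the model's actual attributes.

```python
import torch

# Illustrative stand-ins for the model's buffers and prefill length.
max_seq_len, n_tokens = 8, 3
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
mask_view = causal_mask[None, :n_tokens]  # a view sharing storage with causal_mask

# Allocate a tensor with the same shape/dtype but an explicitly contiguous
# layout, then copy the view's data into it.
contiguous_mask = torch.empty_like(mask_view, memory_format=torch.contiguous_format)
contiguous_mask.copy_(mask_view)

assert contiguous_mask.is_contiguous()
assert contiguous_mask.data_ptr() != causal_mask.data_ptr()  # fresh allocation
assert torch.equal(contiguous_mask, mask_view)                # same values
```

The copy decouples the example input from the original buffer's storage and strides, which is what the diff's "Make input_pos and mask contiguous in memory" comment describes.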