examples/models/llama3_2_vision/text_decoder (1 file changed, 13 additions, 2 deletions)

@@ -167,11 +167,22 @@ def get_example_inputs(self):
     def get_example_kwarg_inputs(self):
         # For export we must use the prefill versions of the
         # causal mask and input_pos.
+
+        # Make input_pos and mask contiguous in memory.
+        input_pos = self.input_pos[None, : self.n_tokens]
+        mask = self.causal_mask[None, : self.n_tokens]
+        contiguous_input_pos = torch.empty_like(
+            input_pos, memory_format=torch.contiguous_format
+        )
+        contiguous_input_pos.data.copy_(input_pos.data)
+        contiguous_mask = torch.empty_like(mask, memory_format=torch.contiguous_format)
+        contiguous_mask.data.copy_(mask.data)
+
         # Hardcoding # of tiles to be 2. image tokens per tile is 1601.
         if self.use_kv_cache:
             return {
-                "input_pos": self.input_pos[None, : self.n_tokens],
-                "mask": self.causal_mask[None, : self.n_tokens],
+                "input_pos": contiguous_input_pos,
+                "mask": contiguous_mask,
                 "encoder_input": torch.randn(
                     1, self.encoder_max_seq_len, self.model_.dim, dtype=self.dtype
                 ),