@@ -162,7 +162,9 @@ def rescale_and_normalize(
 
     def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
         image_arr = np.asarray(image, dtype=np.uint8)
-        image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous()
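+        # HWC -> CHW as a view only; the contiguous copy now happens later, during patch reordering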
+        image_data = torch.from_numpy(image_arr).permute(2, 0, 1)
+
         grouped_images, grouped_images_index = group_images_by_shape(
             [image_data], disable_grouping=self.disable_grouping
         )
@@ -183,27 +184,44 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
                 interpolation=self.interpolation,
             )
             resized_images_grouped[shape] = stacked_images
+
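+        # drop references so the pre-resize tensors can be freed as early as possible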
+        grouped_images = None
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+        resized_images_grouped = None
 
-        # Group images by size for further processing
-        # Needed in case do_resize is False, or resize returns images with different sizes
         grouped_images, grouped_images_index = group_images_by_shape(
             resized_images, disable_grouping=self.disable_grouping
         )
+        resized_images = None
+
         processed_images_grouped = {}
         processed_grids = {}
+
         for shape, stacked_images in grouped_images.items():
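+            # move the batch to GPU up front; non_blocking=True only overlaps the copy when the source is in pinned memory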
+            stacked_images = stacked_images.to("cuda", non_blocking=True)
+
             resized_height, resized_width = stacked_images.shape[-2:]
-            # Fused rescale and normalize
+
             patches = self.rescale_and_normalize(
-                stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std
+                stacked_images,
+                self.do_rescale,
+                self.rescale_factor,
+                self.do_normalize,
+                self.image_mean,
+                self.image_std,
             )
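+            # still images arrive as (batch, channel, height, width); add a temporal axis so they share the video path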
             if patches.ndim == 4:
-                # add a temporal dimension if we have images
                 patches = patches.unsqueeze(1)
+
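+            # repeat the last frame so the temporal length is a multiple of temporal_patch_size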
             if patches.shape[1] % self.temporal_patch_size != 0:
                 repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1)
                 patches = torch.cat([patches, repeats], dim=1)
+
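+            # grid sizes: patch counts along the temporal and spatial axes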
             batch_size, grid_t, channel = patches.shape[:3]
             grid_t = grid_t // self.temporal_patch_size
             grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
@@ -224,8 +237,8 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
                 .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
                 .contiguous()
             )
-            # Reorder dimensions to group grid and patch information for subsequent flattening.
-            # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)
+
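+            # patches is now (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)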
             flatten_patches = patches.view(
                 batch_size,
                 grid_t * grid_h * grid_w,
@@ -235,9 +247,13 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
             processed_images_grouped[shape] = flatten_patches
             processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
 
+        grouped_images = None
+
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
         processed_grids = reorder_images(processed_grids, grouped_images_index)
-        pixel_values = torch.cat(processed_images, dim=0)  # (num_patches_total, C*T*ps*ps)
+
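+        # concatenate to (num_patches_total, C * temporal_patch_size * patch_size * patch_size)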
+        pixel_values = torch.cat(processed_images, dim=0)
         image_grid_thw = torch.as_tensor(processed_grids)
 
         return pixel_values, image_grid_thw