@@ -199,6 +199,22 @@ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=N
199199 all_videos_or_images_features .append (feat )
200200 return all_videos_or_images_features
201201
def add_token_per_grid(self, image_feature):
    """Append the newline embedding after every row of each frame's token grid.

    Args:
        image_feature: Tensor indexed as (frame, token, channel) -- assumed
            3-D; the token axis is interpreted as a square ``h x h`` grid
            laid out row by row.

    Returns:
        Tensor of shape ``(num_frames * h * (h + 1), channels)``: rows are
        emitted frame by frame, and each row's ``h`` tokens are followed by
        one copy of ``self.model.image_newline``. The result is a transposed
        view and is not guaranteed to be contiguous.

    Raises:
        ValueError: If the number of tokens per frame is not a perfect
            square, so no ``h x h`` grid can be formed.
    """
    num_frames = image_feature.shape[0]
    num_tokens = image_feature.shape[1]
    resize_h = int(math.sqrt(num_tokens))
    # Fail early with a readable message; otherwise the reshape/expand below
    # dies with an opaque shape-mismatch error.
    if resize_h * resize_h != num_tokens:
        raise ValueError(
            f"add_token_per_grid expects a square number of tokens per frame, got {num_tokens}"
        )

    # (frames, tokens, C) -> (frames, 1, h, h, C)
    image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
    # Channels first so the newline vector can be broadcast over every row:
    # (C, frames, h, 1, h)
    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
    # Merge (frames, h) into one row axis and (1, h) into one column axis:
    # (C, frames * h, h)
    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
    # One newline column per row: (C, frames * h, 1)
    newline = self.model.image_newline[:, None, None]
    newline = newline.expand(*image_feature.shape[:-1], 1).to(image_feature.device)
    image_feature = torch.cat((image_feature, newline), dim=-1)
    # Back to a token sequence: (frames * h * (h + 1), C)
    image_feature = image_feature.flatten(1, 2).transpose(0, 1)
    return image_feature
211+
def add_token_per_frame(self, image_feature):
    """Append one newline embedding at the end of every frame's token sequence.

    Args:
        image_feature: 3-D tensor indexed as (frame, token, channel).

    Returns:
        Contiguous tensor of shape ``(frames, tokens + 1, channels)`` whose
        last token of each frame is ``self.model.image_newline``.
    """
    # Channels first so the newline vector lines up with the leading axis:
    # (C, frames, tokens)
    channels_first = image_feature.permute(2, 0, 1).contiguous()
    # One newline slot per frame: (C, frames, 1)
    newline_column = self.model.image_newline[:, None, None]
    newline_column = newline_column.expand(*channels_first.shape[:-1], 1)
    newline_column = newline_column.to(channels_first.device)
    extended = torch.cat((channels_first, newline_column), dim=-1)
    # Restore the (frames, tokens + 1, C) layout.
    return extended.permute(1, 2, 0).contiguous()
217+
202218 def prepare_inputs_labels_for_multimodal (self , input_ids , position_ids , attention_mask , past_key_values , labels , images , modalities = ["image" ], image_sizes = None ):
203219 vision_tower = self .get_vision_tower ()
204220 # rank_print(modalities)
@@ -253,12 +269,31 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
253269 # rank0_print("At least we are reaching here")
254270 if image_idx in video_idx_in_batch : # video operations
255271 # rank0_print("Video")
256- if "unpad" in mm_patch_merge_type :
257- # image_feature = image_feature.permute(2, 0, 1).contiguous()
258- # image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
259- # image_feature = image_feature.permute(1, 2, 0).contiguous()
272+ if self .config .mm_newline_position == "grid" :
273+ # Grid-wise
274+ image_feature = self .add_token_per_grid (image_feature )
275+
276+ new_image_features .append (image_feature )
277+ elif self .config .mm_newline_position == "frame" :
278+ # Frame-wise
279+ image_feature = self .add_token_per_frame (image_feature )
280+
281+ new_image_features .append (image_feature .flatten (0 , 1 ))
282+
283+ elif self .config .mm_newline_position == "one_token" :
284+ # one-token
260285 image_feature = image_feature .flatten (0 , 1 )
261- image_feature = torch .cat ((image_feature , self .model .image_newline [None ].to (image_feature .device )), dim = 0 )
286+ if 'unpad' in mm_patch_merge_type :
287+ image_feature = torch .cat ((
288+ image_feature ,
289+ self .model .image_newline [None ].to (image_feature .device )
290+ ), dim = 0 )
291+ new_image_features .append (image_feature )
292+ elif self .config .mm_newline_position == "no_token" :
293+ new_image_features .append (image_feature .flatten (0 , 1 ))
294+ else :
295+ raise ValueError (f"Unexpected mm_newline_position: { self .config .mm_newline_position } " )
296+
262297
263298 elif image_feature .shape [0 ] > 1 : # multi patches and multi images operations
264299 # rank0_print("Single-images")
0 commit comments