@@ -93,6 +93,13 @@ def initialize_vision_modules(self, model_args, fsdp=None):
9393 self .config .mm_vision_select_feature = mm_vision_select_feature
9494 self .config .mm_patch_merge_type = mm_patch_merge_type
9595
96+ if not hasattr (self .config , 'add_faster_video' ):
97+ if model_args .add_faster_video :
98+ embed_std = 1 / torch .sqrt (torch .tensor (self .config .hidden_size , dtype = self .dtype ))
99+ self .faster_token = nn .Parameter (
100+ torch .randn (self .config .hidden_size , dtype = self .dtype ) * embed_std
101+ )
102+
96103 if getattr (self , "mm_projector" , None ) is None :
97104 self .mm_projector = build_vision_projector (self .config , vision_cfg = vision_tower .config )
98105
@@ -160,19 +167,19 @@ def get_model(self):
160167 def get_vision_tower (self ):
161168 return self .get_model ().get_vision_tower ()
162169
163- def get_2dPool (self , image_feature ):
170+ def get_2dPool (self , image_feature , stride = 2 ):
164171 height = width = self .get_vision_tower ().num_patches_per_side
165172 num_frames , num_tokens , num_dim = image_feature .shape
166173 image_feature = image_feature .view (num_frames , height , width , - 1 )
167174 image_feature = image_feature .permute (0 , 3 , 1 , 2 ).contiguous ()
168175 # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
169176 if self .config .mm_spatial_pool_mode == "average" :
170- image_feature = nn .functional .avg_pool2d (image_feature , self . config . mm_spatial_pool_stride )
177+ image_feature = nn .functional .avg_pool2d (image_feature , stride )
171178 elif self .config .mm_spatial_pool_mode == "max" :
172- image_feature = nn .functional .max_pool2d (image_feature , self . config . mm_spatial_pool_stride )
179+ image_feature = nn .functional .max_pool2d (image_feature , stride )
173180 elif self .config .mm_spatial_pool_mode == "bilinear" :
174181 height , weight = image_feature .shape [2 :]
175- scaled_shape = [math .ceil (height / 2 ), math .ceil (weight / 2 )]
182+ scaled_shape = [math .ceil (height / stride ), math .ceil (weight / stride )]
176183 image_feature = nn .functional .interpolate (image_feature , size = scaled_shape , mode = 'bilinear' )
177184
178185 else :
@@ -191,13 +198,54 @@ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=N
191198 videos_or_images_features = self .get_model ().get_vision_tower ()(videos_or_images )
192199 per_videos_or_images_features = torch .split (videos_or_images_features , split_sizes , dim = 0 ) # tuple, (dim_1, 576, 4096)
193200 all_videos_or_images_features = []
201+ all_faster_video_features = []
202+ cur_mm_spatial_pool_stride = self .config .mm_spatial_pool_stride
194203
195204 for idx , feat in enumerate (per_videos_or_images_features ):
205+
196206 feat = self .get_model ().mm_projector (feat )
197- if idx in video_idx_in_batch :
198- feat = self .get_2dPool (feat )
199- all_videos_or_images_features .append (feat )
200- return all_videos_or_images_features
207+ faster_video_feature = 0
208+ slower_img_feat = 0
209+ if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1 :
210+ slower_img_feat = self .get_2dPool (feat ,cur_mm_spatial_pool_stride )
211+ if self .config .add_faster_video :
212+ cur_mm_spatial_pool_stride = cur_mm_spatial_pool_stride * 2
213+ faster_video_feature = self .get_2dPool (feat ,cur_mm_spatial_pool_stride )
214+ if slower_img_feat is not 0 :
215+ all_videos_or_images_features .append (slower_img_feat )
216+ else :
217+ all_videos_or_images_features .append (feat )
218+ all_faster_video_features .append (faster_video_feature )
219+ return all_videos_or_images_features ,all_faster_video_features
220+
def add_token_per_grid(self, image_feature):
    """Append the image-newline embedding to the end of every grid row.

    Each frame's tokens are read as a square grid of side
    ``int(sqrt(num_tokens))`` and one copy of ``self.model.image_newline``
    is concatenated after the last token of every row.

    Args:
        image_feature: Tensor indexed as (num_frames, num_tokens, dim).
            NOTE(review): num_tokens is assumed to be a perfect square and
            ``image_newline`` is assumed to be 1-D of length ``dim`` --
            confirm against the caller.

    Returns:
        If ``config.add_faster_video`` is truthy: a tensor of shape
        (num_frames, side * (side + 1), dim), keeping frames separate.
        Otherwise: a tensor of shape (num_frames * side * (side + 1), dim)
        with all frames concatenated.
    """
    side = int(math.sqrt(image_feature.shape[1]))
    num_frames = image_feature.shape[0]
    feature_dim = image_feature.shape[-1]

    # (frames, side * side, dim) -> (dim, frames * side, side):
    # channel-first layout with one entry per grid row.
    rows = image_feature.view(num_frames, 1, side, side, -1)
    rows = rows.permute(4, 0, 2, 1, 3).contiguous()
    rows = rows.flatten(1, 2).flatten(2, 3)

    # Broadcast the newline embedding to one extra column per grid row.
    newline = self.model.image_newline[:, None, None].expand(*rows.shape[:-1], 1)
    rows = torch.cat((rows, newline.to(rows.device)), dim=-1)

    # Configs that predate the flag do not define it; treat a missing
    # attribute as "disabled" instead of raising AttributeError.
    if getattr(self.config, "add_faster_video", False):
        # (dim, frames * side, side + 1) -> (frames, side * (side + 1), dim)
        per_frame = rows.view(feature_dim, num_frames, side, -1)
        per_frame = per_frame.permute(1, 2, 3, 0).contiguous()
        return per_frame.flatten(1, 2)

    # (dim, frames * side, side + 1) -> (frames * side * (side + 1), dim)
    return rows.flatten(1, 2).transpose(0, 1)
243+
def add_token_per_frame(self, image_feature):
    """Append one image-newline token after the tokens of every frame.

    Args:
        image_feature: Tensor indexed as (num_frames, num_tokens, dim).
            NOTE(review): ``self.model.image_newline`` is assumed to be
            1-D of length ``dim`` -- confirm against the model definition.

    Returns:
        A contiguous tensor of shape (num_frames, num_tokens + 1, dim)
        whose last token in each frame is the newline embedding.
    """
    # Work channel-first so the newline can be concatenated on the token axis.
    channel_first = image_feature.permute(2, 0, 1).contiguous()
    leading_dims = channel_first.shape[:-1]
    newline_column = self.model.image_newline[:, None, None].expand(*leading_dims, 1)
    newline_column = newline_column.to(channel_first.device)
    with_newline = torch.cat((channel_first, newline_column), dim=-1)
    # Restore the (frames, tokens, dim) layout.
    return with_newline.permute(1, 2, 0).contiguous()
201249
202250 def prepare_inputs_labels_for_multimodal (self , input_ids , position_ids , attention_mask , past_key_values , labels , images , modalities = ["image" ], image_sizes = None ):
203251 vision_tower = self .get_vision_tower ()
@@ -224,6 +272,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
224272 concat_images = torch .cat ([image for image in images_list ], dim = 0 )
225273 split_sizes = [image .shape [0 ] for image in images_list ]
226274 encoded_image_features = self .encode_images (concat_images )
275+ # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
227276
228277 # This is a list, each element is [num_images, patch * patch, dim]
229278 # rank_print(f"Concat images : {concat_images.shape}")
@@ -239,6 +288,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
239288 # image_features = torch.split(image_features, split_sizes, dim=0)
240289 mm_patch_merge_type = getattr (self .config , "mm_patch_merge_type" , "flat" )
241290 image_aspect_ratio = getattr (self .config , "image_aspect_ratio" , "square" )
291+ mm_newline_position = getattr (self .config , "mm_newline_position" , "one_token" )
242292
243293 if mm_patch_merge_type == "flat" :
244294 image_features = [x .flatten (0 , 1 ) for x in image_features ]
@@ -253,13 +303,44 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
253303 # rank0_print("At least we are reaching here")
254304 if image_idx in video_idx_in_batch : # video operations
255305 # rank0_print("Video")
256- if "unpad" in mm_patch_merge_type :
257- # image_feature = image_feature.permute(2, 0, 1).contiguous()
258- # image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
259- # image_feature = image_feature.permute(1, 2, 0).contiguous()
306+ if mm_newline_position == "grid" :
307+ # Grid-wise
308+ image_feature = self .add_token_per_grid (image_feature )
309+ if self .config .add_faster_video :
310+ faster_video_feature = self .add_token_per_grid (all_faster_video_features [image_idx ])
311+ # Add a token for each frame
312+ concat_slow_fater_token = []
313+ # import pdb; pdb.set_trace()
314+ for _ in range (image_feature .shape [0 ]):
315+ if _ % self .config .faster_token_stride == 0 :
316+ concat_slow_fater_token .append (torch .cat ((image_feature [_ ], self .model .faster_token [None ].to (image_feature .device )), dim = 0 ))
317+ else :
318+ concat_slow_fater_token .append (torch .cat ((faster_video_feature [_ ], self .model .faster_token [None ].to (image_feature .device )), dim = 0 ))
319+ # import pdb; pdb.set_trace()
320+ image_feature = torch .cat (concat_slow_fater_token )
321+
322+ # print("!!!!!!!!!!!!")
323+
324+ new_image_features .append (image_feature )
325+ elif mm_newline_position == "frame" :
326+ # Frame-wise
327+ image_feature = self .add_token_per_frame (image_feature )
328+
329+ new_image_features .append (image_feature .flatten (0 , 1 ))
330+
331+ elif mm_newline_position == "one_token" :
332+ # one-token
260333 image_feature = image_feature .flatten (0 , 1 )
261- image_feature = torch .cat ((image_feature , self .model .image_newline [None ].to (image_feature .device )), dim = 0 )
262-
334+ if 'unpad' in mm_patch_merge_type :
335+ image_feature = torch .cat ((
336+ image_feature ,
337+ self .model .image_newline [None ].to (image_feature .device )
338+ ), dim = 0 )
339+ new_image_features .append (image_feature )
340+ elif mm_newline_position == "no_token" :
341+ new_image_features .append (image_feature .flatten (0 , 1 ))
342+ else :
343+ raise ValueError (f"Unexpected mm_newline_position: { mm_newline_position } " )
263344 elif image_feature .shape [0 ] > 1 : # multi patches and multi images operations
264345 # rank0_print("Single-images")
265346 base_image_feature = image_feature [0 ]
@@ -316,12 +397,13 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
316397 pass
317398 else :
318399 image_feature = torch .cat ((base_image_feature , image_feature ), dim = 0 )
400+ new_image_features .append (image_feature )
319401 else : # single image operations
320402 image_feature = image_feature [0 ]
321403 if "unpad" in mm_patch_merge_type :
322404 image_feature = torch .cat ((image_feature , self .model .image_newline [None ]), dim = 0 )
323405
324- new_image_features .append (image_feature )
406+ new_image_features .append (image_feature )
325407 image_features = new_image_features
326408 else :
327409 raise ValueError (f"Unexpected mm_patch_merge_type: { self .config .mm_patch_merge_type } " )
@@ -506,4 +588,4 @@ def initialize_vision_tokenizer(self, model_args, tokenizer):
506588 for p in self .get_input_embeddings ().parameters ():
507589 p .requires_grad = False
508590 for p in self .get_output_embeddings ().parameters ():
509- p .requires_grad = False
591+ p .requires_grad = False
0 commit comments