@@ -93,6 +93,13 @@ def initialize_vision_modules(self, model_args, fsdp=None):
9393 self .config .mm_vision_select_feature = mm_vision_select_feature
9494 self .config .mm_patch_merge_type = mm_patch_merge_type
9595
96+ if not hasattr (self .config , 'add_faster_video' ):
97+ if model_args .add_faster_video :
98+ embed_std = 1 / torch .sqrt (torch .tensor (self .config .hidden_size , dtype = self .dtype ))
99+ self .faster_token = nn .Parameter (
100+ torch .randn (self .config .hidden_size , dtype = self .dtype ) * embed_std
101+ )
102+
96103 if getattr (self , "mm_projector" , None ) is None :
97104 self .mm_projector = build_vision_projector (self .config , vision_cfg = vision_tower .config )
98105
@@ -160,19 +167,19 @@ def get_model(self):
def get_vision_tower(self):
    """Return the vision tower held by the wrapped language model."""
    base_model = self.get_model()
    return base_model.get_vision_tower()
162169
163- def get_2dPool (self , image_feature ):
170+ def get_2dPool (self , image_feature , stride = 2 ):
164171 height = width = self .get_vision_tower ().num_patches_per_side
165172 num_frames , num_tokens , num_dim = image_feature .shape
166173 image_feature = image_feature .view (num_frames , height , width , - 1 )
167174 image_feature = image_feature .permute (0 , 3 , 1 , 2 ).contiguous ()
168175 # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
169176 if self .config .mm_spatial_pool_mode == "average" :
170- image_feature = nn .functional .avg_pool2d (image_feature , self . config . mm_spatial_pool_stride )
177+ image_feature = nn .functional .avg_pool2d (image_feature , stride )
171178 elif self .config .mm_spatial_pool_mode == "max" :
172- image_feature = nn .functional .max_pool2d (image_feature , self . config . mm_spatial_pool_stride )
179+ image_feature = nn .functional .max_pool2d (image_feature , stride )
173180 elif self .config .mm_spatial_pool_mode == "bilinear" :
174181 height , weight = image_feature .shape [2 :]
175- scaled_shape = [math .ceil (height / 2 ), math .ceil (weight / 2 )]
182+ scaled_shape = [math .ceil (height / stride ), math .ceil (weight / stride )]
176183 image_feature = nn .functional .interpolate (image_feature , size = scaled_shape , mode = 'bilinear' )
177184
178185 else :
@@ -191,21 +198,46 @@ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=N
191198 videos_or_images_features = self .get_model ().get_vision_tower ()(videos_or_images )
192199 per_videos_or_images_features = torch .split (videos_or_images_features , split_sizes , dim = 0 ) # tuple, (dim_1, 576, 4096)
193200 all_videos_or_images_features = []
201+ all_faster_video_features = []
202+ cur_mm_spatial_pool_stride = self .config .mm_spatial_pool_stride
194203
195204 for idx , feat in enumerate (per_videos_or_images_features ):
205+
196206 feat = self .get_model ().mm_projector (feat )
197- if idx in video_idx_in_batch :
198- feat = self .get_2dPool (feat )
199- all_videos_or_images_features .append (feat )
200- return all_videos_or_images_features
207+ faster_video_feature = 0
208+            slower_img_feat = None
209+            if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1:
210+                slower_img_feat = self.get_2dPool(feat, cur_mm_spatial_pool_stride)
211+                if self.config.add_faster_video:
212+                    # NOTE(review): pool with a local 2x stride; reassigning cur_mm_spatial_pool_stride here compounded the stride across batch items
213+                    faster_video_feature = self.get_2dPool(feat, cur_mm_spatial_pool_stride * 2)
214+            if slower_img_feat is not None:
215+ all_videos_or_images_features .append (slower_img_feat )
216+ else :
217+ all_videos_or_images_features .append (feat )
218+ all_faster_video_features .append (faster_video_feature )
219+ return all_videos_or_images_features ,all_faster_video_features
201220
def add_token_per_grid(self, image_feature):
    """Append the learned ``image_newline`` token after each grid row of frame features.

    Args:
        image_feature: tensor of shape (num_frames, num_patches, dim); num_patches is
            assumed to be a perfect square (resize_h ** 2) -- TODO confirm with callers.

    Returns:
        If ``config.add_faster_video`` is set: (num_frames, resize_h * (resize_h + 1), dim),
        keeping the frame axis so the caller can interleave slow/fast frame features.
        Otherwise: (num_frames * resize_h * (resize_h + 1), dim), all frames flattened
        into one token sequence.
    """
    resize_h = int(math.sqrt(image_feature.shape[1]))
    num_frames = image_feature.shape[0]
    feature_dim = image_feature.shape[-1]

    # (F, P, D) -> (D, F*resize_h, resize_h): move the feature dim to the front so the
    # newline vector can be concatenated along the last (grid-row) axis.
    image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
    image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
    image_feature = image_feature.flatten(1, 2).flatten(2, 3)
    # Append one image_newline vector at the end of every grid row:
    # (D, F*resize_h, resize_h) -> (D, F*resize_h, resize_h + 1)
    image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
    if self.config.add_faster_video:
        # Keep a per-frame layout: (D, F, resize_h, resize_h+1) -> (F, resize_h*(resize_h+1), D)
        image_feature = image_feature.view(feature_dim, num_frames, resize_h, -1)
        image_feature = image_feature.permute(1, 2, 3, 0).contiguous()
        return image_feature.flatten(1, 2)
    # Default path: flatten all frames into one (tokens, dim) sequence.
    return image_feature.flatten(1, 2).transpose(0, 1)
211243
@@ -246,6 +278,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
246278 concat_images = torch .cat ([image for image in images_list ], dim = 0 )
247279 split_sizes = [image .shape [0 ] for image in images_list ]
248280 encoded_image_features = self .encode_images (concat_images )
281+            # NOTE(review): disabled call -- but all_faster_video_features is read later when add_faster_video is set, and is undefined on this path: image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
249282
250283 # This is a list, each element is [num_images, patch * patch, dim]
251284 # rank_print(f"Concat images : {concat_images.shape}")
@@ -278,6 +311,20 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
278311 if self .config .mm_newline_position == "grid" :
279312 # Grid-wise
280313 image_feature = self .add_token_per_grid (image_feature )
314+ if self .config .add_faster_video :
315+ faster_video_feature = self .add_token_per_grid (all_faster_video_features [image_idx ])
316+ # Add a token for each frame
317+                    concat_slow_faster_token = []
318+                    # interleave slow frames with faster-path frames, appending one learned faster_token per frame
319+                    for frame_idx in range(image_feature.shape[0]):
320+                        if frame_idx % self.config.faster_token_stride == 0:
321+                            concat_slow_faster_token.append(torch.cat((image_feature[frame_idx], self.model.faster_token[None].to(image_feature.device)), dim=0))
322+                        else:
323+                            concat_slow_faster_token.append(torch.cat((faster_video_feature[frame_idx], self.model.faster_token[None].to(image_feature.device)), dim=0))
324+                    # frames off the stride grid fall back to the lower-resolution faster features
325+                    image_feature = torch.cat(concat_slow_faster_token)
326+
327+ # print("!!!!!!!!!!!!")
281328
282329 new_image_features .append (image_feature )
283330 elif self .config .mm_newline_position == "frame" :
@@ -357,12 +404,13 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
357404 pass
358405 else :
359406 image_feature = torch .cat ((base_image_feature , image_feature ), dim = 0 )
407+ new_image_features .append (image_feature )
360408 else : # single image operations
361409 image_feature = image_feature [0 ]
362410 if "unpad" in mm_patch_merge_type :
363411 image_feature = torch .cat ((image_feature , self .model .image_newline [None ]), dim = 0 )
364412
365- new_image_features .append (image_feature )
413+ new_image_features .append (image_feature )
366414 image_features = new_image_features
367415 else :
368416 raise ValueError (f"Unexpected mm_patch_merge_type: { self .config .mm_patch_merge_type } " )
0 commit comments