1515
1616import os
1717import json
18+ import time
1819from PIL import Image
1920from io import BytesIO
2021from typing import List
@@ -67,7 +68,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
6768 hidden_states = hidden_states .view (
6869 - 1 , self .in_channels , self .temporal_patch_size , self .patch_size , self .patch_size
6970 )
71+ # num_patches = hidden_states.shape[0]
72+ # print(f"num_patches is {num_patches}")
73+ # torch.cuda.synchronize()
74+ # time0 = time.perf_counter()
7075 hidden_states = self .proj (hidden_states ).view (- 1 , self .embed_dim )
76+ # torch.cuda.synchronize()
77+ # print(f"patch embed time is {time.perf_counter()-time0}")
7178 return hidden_states
7279
7380
@@ -194,6 +201,39 @@ def _init_datatype(self):
194201 raise ValueError (f"Unsupport datatype { self .data_type } !" )
195202 return
196203
def concat_img_embed_and_deepstack_features(
    self,
    image_embed: torch.Tensor,
    deepstack_feature_lists: List[torch.Tensor],
    valid_ids: List[List[int]],
):
    """Interleave each image's embedding rows with its deepstack feature rows.

    Layout example for three images and three deepstack features:

        image_embed:              [img_embed1, img_embed2, img_embed3]
        deepstack_feature_lists:  [df1-1, df2-1, df3-1]   (feature 1)
                                  [df1-2, df2-2, df3-2]   (feature 2)
                                  [df1-3, df2-3, df3-3]   (feature 3)
        valid_ids:                [[start_1, end_1], [start_2, end_2], [start_3, end_3]]

        returned tensor:          [img_embed1, df1-1, df1-2, df1-3,
                                   img_embed2, df2-1, df2-2, df2-3,
                                   img_embed3, df3-1, df3-2, df3-3]

    Args:
        image_embed: Tensor whose rows ``[start:end]`` belong to one image.
        deepstack_feature_lists: Tensors sliced with the same ``[start:end]``
            row ranges as ``image_embed``. They are concatenated along dim 0,
            so they are assumed to share ``image_embed``'s trailing
            dimensions -- TODO confirm against the producer in ``forward``.
        valid_ids: Per-image ``[start, end)`` row ranges into ``image_embed``.

    Returns:
        A tuple ``(all_img_embeds_ds, new_valid_ids)``. ``all_img_embeds_ds``
        is the per-image chunks concatenated along dim 0. ``new_valid_ids``
        holds, for every image, the ``[start, end)`` row range of its whole
        chunk (embedding rows followed by all deepstack rows) inside
        ``all_img_embeds_ds``.
    """
    all_chunks = []
    new_valid_ids = []

    # Row index in the output tensor at which the next image's chunk begins.
    row_offset = 0

    for start, end in valid_ids:
        hs_i = image_embed[start:end]
        ds_i_list = [feat[start:end] for feat in deepstack_feature_lists]

        # Embedding rows first, then every deepstack feature for this image.
        combined_i = torch.cat([hs_i, *ds_i_list], dim=0)

        new_end = row_offset + combined_i.size(0)
        new_valid_ids.append([row_offset, new_end])
        all_chunks.append(combined_i)

        # new_end is already an absolute position, so assign it rather than
        # accumulate it; accumulating skews every range after the second.
        row_offset = new_end

    all_img_embeds_ds = torch.cat(all_chunks, dim=0)
    return all_img_embeds_ds, new_valid_ids
236+
197237 def load_model (self , weight_dir ):
198238
199239 processor_config_path = os .path .join (weight_dir , "preprocessor_config.json" )
@@ -320,21 +360,17 @@ def fast_pos_embed_interpolate(self, grid_thw):
320360
321361 def forward (self , hidden_states : torch .Tensor , grid_thw : torch .Tensor , ** kwargs ) -> torch .Tensor :
322362 hidden_states = self .patch_embed (hidden_states )
323-
324363 pos_embeds = self .fast_pos_embed_interpolate (grid_thw )
325364 hidden_states = hidden_states + pos_embeds
326-
327365 rotary_cos , rotary_sin = self .rot_pos_emb (grid_thw )
328366 rotary_cos = rotary_cos .to ("cuda" , non_blocking = True )
329367 rotary_sin = rotary_sin .to ("cuda" , non_blocking = True )
330-
331368 cu_seqlens = torch .repeat_interleave (grid_thw [:, 1 ] * grid_thw [:, 2 ], grid_thw [:, 0 ]).cumsum (
332369 dim = 0 ,
333370 dtype = torch .int32 ,
334371 )
335372 cu_seqlens = F .pad (cu_seqlens , (1 , 0 ), value = 0 ).to ("cuda" , non_blocking = True )
336373 max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ().item ()
337-
338374 deepstack_feature_lists = []
339375 for layer_num , blk in enumerate (self .blocks ):
340376 hidden_states = blk (
@@ -349,6 +385,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
349385 hidden_states
350386 )
351387 deepstack_feature_lists .append (deepstack_feature )
388+ # print(f"ds time is {time.perf_counter()-time0}")
352389
353390 hidden_states = self .merger (hidden_states )
354391
@@ -391,7 +428,9 @@ def encode(self, images: List[ImageItem]):
391428
392429 pixel_values = imgs .to ("cuda" , dtype = self .data_type , non_blocking = True )
393430 image_grid_thw = grid_thw .to ("cuda" , non_blocking = True )
431+ img_embeds , deepstack_feature_lists = self .forward (pixel_values , grid_thw = image_grid_thw )
432+ all_img_embeds_df , valid_ids = self .concat_img_embed_and_deepstack_features (
433+ img_embeds , deepstack_feature_lists , valid_ids
434+ )
394435
395- all_img_embeds , deepstack_feature_lists = self .forward (pixel_values , grid_thw = image_grid_thw )
396-
397- return all_img_embeds , uuids , valid_ids , deepstack_feature_lists
436+ return all_img_embeds_df , uuids , valid_ids
0 commit comments