66
77import torch
88import torch .nn as nn
9+ import torch .nn .functional as F
910import numpy as np
1011from transformers import (
1112 CLIPVisionModel ,
1415)
1516
1617from .configuration_mineru2 import Mineru2QwenConfig
17- from .image_processing_mineru2 import Mineru2ImageProcessor , expand2square , process_anyres_image
18+ from .image_processing_mineru2 import (
19+ Mineru2ImageProcessor ,
20+ expand2square ,
21+ process_anyres_image ,
22+ get_anyres_image_grid_shape ,
23+ )
1824
1925from lightllm .server .multimodal_params import ImageItem
2026from lightllm .server .embed_cache .utils import read_shm , get_shm_name_data
@@ -179,7 +185,9 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
179185 uuids .append (img .uuid )
180186 image_data = read_shm (get_shm_name_data (img .uuid ))
181187 image_data = Image .open (BytesIO (image_data )).convert ("RGB" )
182- if image_aspect_ratio == "pad" :
188+ # 多图/视频强制 pad,单图才允许 anyres
189+ force_pad = len (images ) > 1
190+ if image_aspect_ratio == "pad" or force_pad :
183191 image_proc = expand2square (image_data , tuple (int (x * 255 ) for x in self .image_processor .image_mean ))
184192 t = self .image_processor .preprocess (image_proc , return_tensors = "pt" )["pixel_values" ]
185193 elif image_aspect_ratio and (image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio ):
@@ -194,16 +202,18 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
194202 elif t .ndim == 3 :
195203 t = t .unsqueeze (0 )
196204
197- # 对齐实际视图数 K 与期望 token(可能是 K 或 K*patch_len)
198- expected_token = img .token_num if getattr (img , "token_num" , None ) is not None else None
205+ # 对齐实际视图数 K 与期望视图数(anyres: Nx*Ny+1;否则:1)
199206 actual_k = t .shape [0 ]
200- if expected_token is None or expected_token <= 0 :
201- expected_views = actual_k
207+ if (
208+ image_aspect_ratio and (image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio )
209+ ) and not force_pad :
210+ crop_size = self .image_processor .crop_size ["height" ]
211+ grid_w , grid_h = get_anyres_image_grid_shape (
212+ (img .image_w , img .image_h ), image_grid_pinpoints , crop_size
213+ )
214+ expected_views = int (grid_w * grid_h + 1 )
202215 else :
203- if expected_token >= patch_len and expected_token % patch_len == 0 :
204- expected_views = expected_token // patch_len
205- else :
206- expected_views = expected_token
216+ expected_views = 1
207217 if actual_k != expected_views :
208218 if actual_k % expected_views == 0 :
209219 factor = actual_k // expected_views
@@ -219,26 +229,86 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
219229 pad = t [- 1 :].repeat (expected_views - actual_k , 1 , 1 , 1 )
220230 t = torch .cat ([t , pad ], dim = 0 )
221231 img_tensors .append (t )
222- # 最终视图数 K
223- final_views = t .shape [0 ]
224- # 对齐 patch 序列后的总 token 数
225- img .token_num = final_views * patch_len
226232 else :
227233 raise Exception ("Unsupport input types: {} for {}" .format (type (img ), img ))
228234
229- # 本图对应的 token 数(视图 * patch_len)
230- if isinstance (img_tensors [- 1 ], torch .Tensor ) and img_tensors [- 1 ].dim () == 4 :
231- cur_num = img_tensors [- 1 ].shape [0 ] * patch_len
232- else :
233- cur_num = patch_len
234- valid_ids .append ([valid_id , valid_id + cur_num ])
235- valid_id += cur_num
235+ # 暂不累加 valid_ids,待完成重组后依据真实长度填写
236236
237237 if len (img_tensors ) <= 0 :
238238 return None , [], []
239239 # 保证全部为4维后拼接
240240 img = torch .cat (img_tensors , dim = 0 )
241241 img = img .cuda ()
242+ # 提取所有视图的 patch 序列嵌入(views * patch_len, hidden)
242243 all_img_embeds = self .forward (img )
243244
244- return all_img_embeds , uuids , valid_ids
245+ # 将每张图的视图嵌入进行 spatial+unpad(+anyres_max) 重组,并追加换行列
246+ new_embeds : List [torch .Tensor ] = []
247+ cur = 0
248+ for i , img in enumerate (images ):
249+ # 计算本图视图数
250+ t = img_tensors [i ]
251+ K = t .shape [0 ]
252+ # 取出本图的所有 view 的 patch 序列嵌入
253+ tokens_len = K * patch_len
254+ cur_views_embeds = all_img_embeds [cur : cur + tokens_len ]
255+ cur += tokens_len
256+
257+ # 非 anyres 或多图/视频强制 pad:直接使用展平序列(K 通常为 1)
258+ force_pad = len (images ) > 1
259+ aspect = getattr (self .image_processor , "image_aspect_ratio" , None )
260+ if not aspect or ("anyres" not in str (aspect )) or force_pad or K <= 1 :
261+ seq = cur_views_embeds
262+ new_embeds .append (seq )
263+ # 记录区间
264+ valid_ids .append ([valid_id , valid_id + seq .shape [0 ]])
265+ valid_id += seq .shape [0 ]
266+ continue
267+
268+ # anyres 单图路径:
269+ # 切分 base 视图与其余视图
270+ base_feature = cur_views_embeds [:patch_len ]
271+ rest = cur_views_embeds [patch_len :]
272+ # (K-1, patch_len, hidden)
273+ hidden = rest .shape [- 1 ]
274+ rest = rest .view (K - 1 , patch_len , hidden )
275+
276+ # 计算 Nx, Ny
277+ crop_size = self .image_processor .crop_size ["height" ]
278+ grid_w , grid_h = get_anyres_image_grid_shape ((img .image_w , img .image_h ), image_grid_pinpoints , crop_size )
279+ # (Ny, Nx, patch_side, patch_side, hidden)
280+ rest = rest .view (grid_w * grid_h , patch_side , patch_side , hidden )
281+ rest = rest .view (grid_h , grid_w , patch_side , patch_side , hidden )
282+ # (hidden, Ny, patch_side, Nx, patch_side) -> (hidden, H, W)
283+ rest = rest .permute (4 , 0 , 2 , 1 , 3 ).contiguous ()
284+ H = grid_h * patch_side
285+ W = grid_w * patch_side
286+ rest = rest .view (hidden , H , W )
287+
288+ # anyres_max 下采样
289+ m = re .search (r"anyres_max_(\d+)" , str (aspect ))
290+ if m is not None :
291+ max_num_patches = int (m .group (1 ))
292+ times = (H * W ) / (max_num_patches * patch_len )
293+ if times > 1.1 :
294+ scale = (int (H // (times ** 0.5 )), int (W // (times ** 0.5 )))
295+ rest = F .interpolate (rest .unsqueeze (0 ), size = scale , mode = "bilinear" , align_corners = False )[0 ]
296+ H , W = rest .shape [1 ], rest .shape [2 ]
297+
298+ # 追加换行列(列数+1),换行列取 0 向量占位
299+ newline_col = torch .zeros ((hidden , H , 1 ), device = rest .device , dtype = rest .dtype )
300+ rest = torch .cat ([rest , newline_col ], dim = 2 ) # (hidden, H, W+1)
301+ # 展平成 (H*(W+1), hidden)
302+ rest = rest .flatten (1 , 2 ).transpose (0 , 1 ).contiguous ()
303+
304+ # 拼接 base + 其余
305+ seq = torch .cat ([base_feature , rest ], dim = 0 )
306+ new_embeds .append (seq )
307+
308+ # 记录区间
309+ valid_ids .append ([valid_id , valid_id + seq .shape [0 ]])
310+ valid_id += seq .shape [0 ]
311+
312+ # 拼接所有图的重组后嵌入
313+ all_new = torch .cat (new_embeds , dim = 0 )
314+ return all_new , uuids , valid_ids
0 commit comments