@@ -125,8 +125,52 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
125125 elif t .ndim == 3 :
126126 print (f"[debug] mineru2_visual unsqueeze t.ndim: { t .ndim } , t.shape: { t .shape } " )
127127 t = t .unsqueeze (0 )
128+ # 在修改前记录 manager 分配的 token_num
129+ try :
130+ print (f"[debug] mineru2_visual manager_token_num_before={ img .token_num } uuid={ img .uuid } " )
131+ except Exception :
132+ pass
133+ # 对齐实际 K 与期望 token_num
134+ expected_k = img .token_num if getattr (img , "token_num" , None ) is not None else None
135+ actual_k = t .shape [0 ]
136+ if expected_k is None or expected_k <= 0 :
137+ expected_k = actual_k
138+ print (f"[debug] mineru2_visual expected_k_from_actual uuid={ img .uuid } expected_k={ expected_k } " )
139+ if actual_k != expected_k :
140+ if actual_k % expected_k == 0 :
141+ factor = actual_k // expected_k
142+ print (
143+ f"[debug] mineru2_visual down_aggregate uuid={ img .uuid } "
144+ f" actual_k={ actual_k } expected_k={ expected_k } factor={ factor } "
145+ )
146+ t = t .view (expected_k , factor , t .shape [1 ], t .shape [2 ], t .shape [3 ]).mean (dim = 1 )
147+ elif expected_k % actual_k == 0 :
148+ factor = expected_k // actual_k
149+ print (
150+ f"[debug] mineru2_visual up_repeat uuid={ img .uuid } "
151+ f" actual_k={ actual_k } expected_k={ expected_k } factor={ factor } "
152+ )
153+ t = t .repeat_interleave (repeats = factor , dim = 0 )
154+ else :
155+ k = min (actual_k , expected_k )
156+ print (
157+ f"[debug] mineru2_visual fallback_slice uuid={ img .uuid } "
158+ f" actual_k={ actual_k } expected_k={ expected_k } k={ k } "
159+ )
160+ if actual_k >= expected_k :
161+ t = t [:expected_k ]
162+ else :
163+ # pad by repeating last
164+ pad = t [- 1 :].repeat (expected_k - actual_k , 1 , 1 , 1 )
165+ t = torch .cat ([t , pad ], dim = 0 )
128166 img_tensors .append (t )
129- img .token_num = t .shape [0 ]
167+ # 最终 K
168+ final_k = t .shape [0 ]
169+ img .token_num = final_k
170+ print (
171+ f"[debug] mineru2_visual actual_k={ actual_k } "
172+ f"expected_k={ expected_k } final_k={ final_k } uuid={ img .uuid } "
173+ )
130174 else :
131175 raise Exception ("Unsupport input types: {} for {}" .format (type (img ), img ))
132176
@@ -136,6 +180,10 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
136180 else 1
137181 )
138182 valid_ids .append ([valid_id , valid_id + cur_num ])
183+ print (
184+ f"[debug] mineru2_visual valid_ids_append uuid={ img .uuid } "
185+ f" range=({ valid_id } ,{ valid_id + cur_num } ) cur_num={ cur_num } "
186+ )
139187 valid_id += cur_num
140188
141189 if len (img_tensors ) <= 0 :
@@ -144,5 +192,6 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
144192 img = torch .cat (img_tensors , dim = 0 )
145193 img = img .cuda ()
146194 all_img_embeds = self .forward (img )
195+ print (f"[debug] mineru2_visual all_img_embeds.shape={ tuple (all_img_embeds .shape )} " f"total_K={ img .shape [0 ]} " )
147196
148197 return all_img_embeds , uuids , valid_ids
0 commit comments