@@ -9,7 +9,6 @@
 import numpy as np
 from transformers import (
     CLIPVisionModel,
-    CLIPVisionConfig,
     SiglipVisionConfig,
     SiglipVisionModel,
 )
@@ -30,32 +29,16 @@
 def build_vision_tower(weight_dir: str, config: Mineru2QwenConfig):
     vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
     model_path = os.path.join(weight_dir, vision_tower)
-
-    def _resolve_path():
-        if os.path.exists(model_path):
-            return model_path
-        else:
-            return vision_tower
+    if not os.path.exists(model_path):
+        model_path = vision_tower
 
     if "clip" in vision_tower.lower():
-        vt_path = _resolve_path()
-        print(f"[debug] load clip from {vt_path}")
-        return CLIPVisionModel.from_pretrained(vt_path)
+        return CLIPVisionModel.from_pretrained(model_path)
     elif "siglip" in vision_tower.lower():
-        vt_path = _resolve_path()
-        print(f"[debug] load siglip from {vt_path}")
-        # Plan A: trim a layer via the config, instantiate the model from that config, then load weights (ignoring mismatched sizes)
-        cfg = SiglipVisionConfig.from_pretrained(vt_path)
-        old_layers = getattr(cfg, "num_hidden_layers", None)
+        cfg = SiglipVisionConfig.from_pretrained(model_path)
         cfg.num_hidden_layers = max(0, cfg.num_hidden_layers - 1)
         cfg.vision_use_head = False
-        model = SiglipVisionModel.from_pretrained(vt_path, config=cfg, ignore_mismatched_sizes=True)
-        try:
-            actual_layers = len(model.vision_model.encoder.layers)  # type: ignore[attr-defined]
-        except Exception:
-            actual_layers = None
-        new_cfg_layers = getattr(getattr(model, "config", None), "num_hidden_layers", None)
-        print(f"[debug] siglip_layers planA old={old_layers} new_cfg={new_cfg_layers} actual_module={actual_layers}")
+        model = SiglipVisionModel.from_pretrained(model_path, config=cfg, ignore_mismatched_sizes=True)
         return model
     else:
         raise ValueError(f"Unknown vision tower: {vision_tower}")
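For context, the siglip branch trims the final encoder layer by editing the config before `from_pretrained`, rather than slicing the module afterwards. A minimal sketch of the trick, assuming the `google/siglip-so400m-patch14-384` checkpoint (not named in this diff):

```python
# Sketch: build a SigLIP tower one encoder layer short via its config.
# The checkpoint name below is an assumption, not from this codebase.
from transformers import SiglipVisionConfig, SiglipVisionModel

ckpt = "google/siglip-so400m-patch14-384"
cfg = SiglipVisionConfig.from_pretrained(ckpt)
cfg.num_hidden_layers = max(0, cfg.num_hidden_layers - 1)  # drop one layer
cfg.vision_use_head = False  # skip the pooling head; only patch features are used

# Weights for the removed layer end up as unexpected keys and are skipped;
# ignore_mismatched_sizes=True additionally tolerates any shape mismatches.
model = SiglipVisionModel.from_pretrained(ckpt, config=cfg, ignore_mismatched_sizes=True)
assert len(model.vision_model.encoder.layers) == cfg.num_hidden_layers
```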
@@ -87,151 +70,61 @@ def __init__(self):
         pass
 
     def _load_projector_weights(self, weight_dir: str):
-        projector_weight_path = os.path.join(weight_dir, "model.safetensors")
-        print(f"[debug] load projector weights from {projector_weight_path}")
-
         def assign_linear(linear: nn.Linear, w: torch.Tensor = None, b: torch.Tensor = None):
             if w is not None:
                 linear.weight.data.copy_(w.to(dtype=linear.weight.dtype))
             if b is not None and linear.bias is not None:
                 linear.bias.data.copy_(b.to(dtype=linear.bias.dtype))
 
-        # Collect the projector's Linear modules (their order is the write order)
+        projector_weight_path = os.path.join(weight_dir, "model.safetensors")
+
         if isinstance(self.projector, nn.Linear):
-            print(f"[debug] projector type: {type(self.projector)}")
             linear_modules = [self.projector]
         elif isinstance(self.projector, nn.Sequential):
-            print(f"[debug] projector type: {type(self.projector)}")
             linear_modules = [m for m in self.projector if isinstance(m, nn.Linear)]
         else:
-            print(f"[debug] projector type: {type(self.projector)}")
             raise RuntimeError(f"Unsupported projector type: {type(self.projector)}")
 
-        def assign_projector_from_state(sd: dict) -> bool:
-            # Single Linear: prefer matching the whole weight directly; otherwise fall back to the first layer
-            if len(linear_modules) == 1:
-                print("[debug] projector load: single Linear matched (model.mm_projector.*)")
-                w = next(
-                    (sd[k] for k in ("model.mm_projector.weight", "model.mm_projector.linear.weight") if k in sd), None
-                )
-                b = next(
-                    (sd[k] for k in ("model.mm_projector.bias", "model.mm_projector.linear.bias") if k in sd), None
-                )
-                if w is not None:
-                    assign_linear(linear_modules[0], w, b)
-                    return True
-                # Fallback: if per-layer keys exist, take only the first layer
-                w = next(
-                    (
-                        sd[k]
-                        for k in ("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight")
-                        if k in sd
-                    ),
-                    None,
-                )
-                b = next(
-                    (sd[k] for k in ("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias") if k in sd),
-                    None,
-                )
-                if w is not None:
-                    assign_linear(linear_modules[0], w, b)
-                    print("[debug] projector load: fallback to first layer for single Linear")
-                    return True
-                return False
-
-            # Multi-layer (e.g. mlp2x_gelu): match layer by layer against common key names
-            layer_key_map = [
-                (
-                    0,
-                    ("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight"),
-                    ("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias"),
-                ),
-                (
-                    1,
-                    ("model.mm_projector.2.weight", "multi_modal_projector.linear_2.weight"),
-                    ("model.mm_projector.2.bias", "multi_modal_projector.linear_2.bias"),
-                ),
-            ]
-            assigned = 0
-            for idx, w_keys, b_keys in layer_key_map:
-                if idx >= len(linear_modules):
-                    continue
-                w = next((sd[k] for k in w_keys if k in sd), None)
-                b = next((sd[k] for k in b_keys if k in sd), None)
-                if w is not None:
-                    assign_linear(linear_modules[idx], w, b)
-                    assigned += 1
-            if assigned > 0:
-                print(f"[debug] projector load: assigned {assigned} Linear layers")
-                return True
-            return False
-
-        def try_load_vision_tower(sd: dict):
-            # Following ref: strip the 'model.vision_tower.vision_tower.' prefix before loading (optional)
-            if not hasattr(self, "vision_tower") or not isinstance(
-                self.vision_tower, (CLIPVisionModel, SiglipVisionModel)
-            ):
-                return False, 0
-            vt_prefix = "model.vision_tower.vision_tower."
-            vt_sd = {k[len(vt_prefix) :]: v for k, v in sd.items() if k.startswith(vt_prefix)}
-            if not vt_sd:
-                return False, 0
-            try:
-                missing, unexpected = self.vision_tower.load_state_dict(vt_sd, strict=False)
-                num = len(vt_sd)
-                print(
-                    f"[debug] vision_tower load: keys={num}"
-                    f" missing={len(missing) if isinstance(missing, (list, tuple)) else 'n/a'}"
-                    f" unexpected={len(unexpected) if isinstance(unexpected, (list, tuple)) else 'n/a'}"
-                )
-                return True, num
-            except Exception as e:
-                print(f"[warning] vision_tower load_state_dict failed (strict=False): {e}")
-                return False, 0
-
-        # Load only from the specified file (prefer .safetensors, fall back to the same-named .bin)
-        if os.path.isfile(projector_weight_path) and projector_weight_path.endswith(".safetensors"):
-            try:
-                with safe_open(projector_weight_path, framework="pt", device="cpu") as sf:
-                    sd = {k: sf.get_tensor(k) for k in sf.keys()}
-            except Exception as e:
-                raise RuntimeError(f"Failed to read projector weights: {projector_weight_path} due to {e}")
-        else:
-            bin_path = (
-                projector_weight_path[:-14] + ".bin"
-                if projector_weight_path.endswith(".safetensors")
-                else projector_weight_path
-            )
-            if os.path.isfile(bin_path):
-                try:
-                    sd = torch.load(bin_path, map_location="cpu")
-                    if not isinstance(sd, dict):
-                        raise RuntimeError("Loaded non-dict state from bin file")
-                    print(f"[debug] fallback load projector weights from {bin_path}")
-                except Exception as e:
-                    raise RuntimeError(f"Failed to read projector weights: {bin_path} due to {e}")
-            else:
-                raise RuntimeError(f"Projector weight file not found: {projector_weight_path}")
-
-        # Load the projector (required)
-        projector_loaded = assign_projector_from_state(sd)
-        if not projector_loaded:
-            raise RuntimeError(
-                "Projector weights not found in checkpoint. "
-                "Expected keys like 'model.mm_projector.{0,2}.(weight|bias)' or "
-                "'multi_modal_projector.linear_{1,2}.(weight|bias)' "
-                "or 'model.mm_projector.(weight|bias)'."
-            )
-
-        # Optionally load the vision_tower
-        vision_loaded, vision_loaded_keys = try_load_vision_tower(sd)
-        if vision_loaded:
-            print(f"[debug] vision_tower weights loaded from checkpoint: keys={vision_loaded_keys}")
-        else:
-            print("[debug] vision_tower weights not found in checkpoint or skipped; keep pretrained weights")
+        try:
+            with safe_open(projector_weight_path, framework="pt", device="cpu") as sf:
+                sd = {k: sf.get_tensor(k) for k in sf.keys()}
+        except Exception as e:
+            raise RuntimeError(f"Failed to read projector weights: {projector_weight_path} due to {e}")
+
+        # load projector weights
+        layer_key_map = [
+            (
+                0,
+                ("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight"),
+                ("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias"),
+            ),
+            (
+                1,
+                ("model.mm_projector.2.weight", "multi_modal_projector.linear_2.weight"),
+                ("model.mm_projector.2.bias", "multi_modal_projector.linear_2.bias"),
+            ),
+        ]
+        for idx, w_keys, b_keys in layer_key_map:
+            if idx >= len(linear_modules):
+                continue
+            w = next((sd[k] for k in w_keys if k in sd), None)
+            b = next((sd[k] for k in b_keys if k in sd), None)
+            if w is not None:
+                assign_linear(linear_modules[idx], w, b)
+
+        # load vision tower weights
+        vt_prefix = "model.vision_tower.vision_tower."
+        vt_sd = {k[len(vt_prefix) :]: v for k, v in sd.items() if k.startswith(vt_prefix)}
+        if not vt_sd:
+            logger.warning("vision_tower weights not found in checkpoint or skipped; keep pretrained weights")
+            return
+
+        try:
+            self.vision_tower.load_state_dict(vt_sd, strict=False)
+        except Exception as e:
+            logger.warning(f"vision_tower load_state_dict failed (strict=False): {e}")
 
     def load_model(self, weight_dir):
-        print(f"[debug] load vision model: {weight_dir}")
         vision_config = Mineru2QwenConfig.from_pretrained(weight_dir)
 
         self.vision_tower = build_vision_tower(weight_dir, vision_config)
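The new loader reads the whole `model.safetensors` into a dict with `safe_open`, then loads the vision tower non-strictly after stripping the `model.vision_tower.vision_tower.` prefix. A self-contained sketch of that pattern; `load_prefixed` is a hypothetical helper, not part of this codebase:

```python
# Sketch: read a safetensors file into a dict, then load one sub-module
# non-strictly after stripping its serialization prefix.
import torch
from safetensors import safe_open

def load_prefixed(module: torch.nn.Module, path: str, prefix: str):
    with safe_open(path, framework="pt", device="cpu") as sf:
        sd = {k: sf.get_tensor(k) for k in sf.keys()}
    sub = {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
    # strict=False tolerates missing/unexpected keys instead of raising
    missing, unexpected = module.load_state_dict(sub, strict=False)
    return missing, unexpected
```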
@@ -251,50 +144,26 @@ def cuda(self):
         return self
 
     def forward(self, x) -> torch.Tensor:
-        # Runtime checks on shape, dtype, and device
-        try:
-            print(f"[debug] mineru2_visual.forward x.shape={tuple(x.shape)} dtype={x.dtype} device={x.device}")
-        except Exception:
-            pass
         vision_out = self.vision_tower(x, output_hidden_states=True)
         hiddens = vision_out.hidden_states
-        # Number of hidden_states vs. config layer count (usually num_layers + 1)
-        try:
-            cfg_layers = getattr(getattr(self.vision_tower, "config", None), "num_hidden_layers", None)
-            eff_layers = len(hiddens) - 1 if isinstance(hiddens, (list, tuple)) else None
-            print(
-                f"[debug] mineru2_visual.hidden_states len={len(hiddens)}"
-                f" cfg_layers={cfg_layers} eff_layers={eff_layers}"
-            )
-        except Exception:
-            pass
+
         # Match ref's "drop one layer" semantics: prefer the second-to-last layer; fall back to the last if unavailable
         try:
             chosen_idx = -2 if isinstance(hiddens, (list, tuple)) and len(hiddens) >= 2 else -1
             feat = hiddens[chosen_idx]
-            print(f"[debug] mineru2_visual.select_layer idx={chosen_idx} feat.shape={tuple(feat.shape)}")
         except Exception:
             feat = hiddens[-2] if isinstance(hiddens, (list, tuple)) and len(hiddens) >= 2 else hiddens[-1]
         # Back to patch-sequence features: drop CLS (if present), run the sequence through the projector, then flatten to (views*patch, hidden)
         patch_side = self.vision_tower.config.image_size // self.vision_tower.config.patch_size
         patch_len = patch_side * patch_side
         if feat.shape[1] == patch_len + 1:
             feat = feat[:, 1:, :]
-            print(f"[debug] mineru2_visual.drop_cls patch_len={patch_len} feat_no_cls.shape={tuple(feat.shape)}")
         proj_seq = self.projector(feat)
-        try:
-            print(f"[debug] mineru2_visual.projector_seq_out shape={tuple(proj_seq.shape)} (views, patch, hidden)")
-        except Exception:
-            pass
+
         proj = proj_seq.reshape(-1, proj_seq.shape[-1])
-        try:
-            print(f"[debug] mineru2_visual.projector_flat_out shape={tuple(proj.shape)} (views*patch, hidden)")
-        except Exception:
-            pass
         return proj
 
     def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List[List[int]]]:
-        print(f"[debug] mineru2_visual encode images {len(images)}")
         img_tensors: List[torch.Tensor] = []
         uuids: List[str] = []
         valid_id = 0
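For context on the `[-2]` index in `forward`: with `output_hidden_states=True`, HuggingFace vision towers return `num_hidden_layers + 1` hidden states (the embedding output plus one per encoder layer), so `[-2]` is the penultimate layer. A toy illustration with random tensors; all sizes below are assumptions:

```python
# Toy illustration of the feature selection in forward(). Shapes assume
# a 384px tower with 14px patches: 27 * 27 = 729 patches per view.
import torch

views, hidden, num_layers = 2, 1152, 26           # assumed sizes
patch_len = 27 * 27
seq = patch_len + 1                               # +1 for a CLS-like token
hidden_states = [torch.randn(views, seq, hidden) for _ in range(num_layers + 1)]

feat = hidden_states[-2]                          # penultimate layer
if feat.shape[1] == patch_len + 1:                # drop leading CLS if present
    feat = feat[:, 1:, :]
assert feat.shape == (views, patch_len, hidden)
```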
@@ -304,7 +173,7 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
         # patch_len per view (e.g. 384/14=27, 27^2=729)
         patch_side = self.vision_tower.config.image_size // self.vision_tower.config.patch_size
         patch_len = patch_side * patch_side
-        print(f"[debug] mineru2_visual.patch_len={patch_len} (side={patch_side})")
+
         for i, img in enumerate(images):
             if isinstance(img, ImageItem):
                 uuids.append(img.uuid)
@@ -321,59 +190,28 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
                 t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
 
                 if t.ndim == 5:
-                    print(f"[debug] mineru2_visual reshape t.ndim: {t.ndim}, t.shape: {t.shape}")
                     t = t.view(-1, t.shape[-3], t.shape[-2], t.shape[-1])
                 elif t.ndim == 3:
-                    print(f"[debug] mineru2_visual unsqueeze t.ndim: {t.ndim}, t.shape: {t.shape}")
                     t = t.unsqueeze(0)
-                # Record the manager-assigned token_num before changing it (may be the view count or views*patches)
-                try:
-                    print(f"[debug] mineru2_visual manager_token_num_before={img.token_num} uuid={img.uuid}")
-                except Exception:
-                    pass
+
                 # Reconcile the actual view count K with the expected tokens (either K or K*patch_len)
                 expected_token = img.token_num if getattr(img, "token_num", None) is not None else None
                 actual_k = t.shape[0]
                 if expected_token is None or expected_token <= 0:
                     expected_views = actual_k
-                    print(
-                        f"[debug] mineru2_visual expected_views_from_actual uuid={img.uuid}"
-                        f" expected_views={expected_views}"
-                    )
                 else:
                     if expected_token >= patch_len and expected_token % patch_len == 0:
                         expected_views = expected_token // patch_len
-                        print(
-                            f"[debug] mineru2_visual expected_views_from_tokens uuid={img.uuid}"
-                            f" expected_token={expected_token} patch_len={patch_len} expected_views={expected_views}"
-                        )
                     else:
                         expected_views = expected_token
-                        print(
-                            f"[debug] mineru2_visual expected_views_interpret_as_views uuid={img.uuid}"
-                            f" expected_views={expected_views}"
-                        )
                 if actual_k != expected_views:
                     if actual_k % expected_views == 0:
                         factor = actual_k // expected_views
-                        print(
-                            f"[debug] mineru2_visual down_aggregate uuid={img.uuid}"
-                            f" actual_k={actual_k} expected_views={expected_views} factor={factor}"
-                        )
                         t = t.view(expected_views, factor, t.shape[1], t.shape[2], t.shape[3]).mean(dim=1)
                     elif expected_views % actual_k == 0:
                         factor = expected_views // actual_k
-                        print(
-                            f"[debug] mineru2_visual up_repeat uuid={img.uuid}"
-                            f" actual_k={actual_k} expected_views={expected_views} factor={factor}"
-                        )
                         t = t.repeat_interleave(repeats=factor, dim=0)
                     else:
-                        k = min(actual_k, expected_views)
-                        print(
-                            f"[debug] mineru2_visual fallback_slice uuid={img.uuid}"
-                            f" actual_k={actual_k} expected_views={expected_views} k={k}"
-                        )
                         if actual_k >= expected_views:
                             t = t[:expected_views]
                         else:
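The surviving reconciliation logic covers three cases: mean-pool groups of views when the processor produced an integer multiple of the expected count, repeat each view when it produced an integer fraction, and slice (or pad, in the branch that follows) otherwise. A standalone sketch of the same arithmetic:

```python
# Standalone sketch of the view-count reconciliation above.
# t is (K, C, H, W); the tensor sizes in the demo call are toy values.
import torch

def align_views(t: torch.Tensor, expected_views: int) -> torch.Tensor:
    actual_k = t.shape[0]
    if actual_k == expected_views:
        return t
    if actual_k % expected_views == 0:
        factor = actual_k // expected_views
        # group consecutive views and average each group down to one view
        return t.view(expected_views, factor, *t.shape[1:]).mean(dim=1)
    if expected_views % actual_k == 0:
        factor = expected_views // actual_k
        # duplicate each view in place: [a, b] -> [a, a, b, b] for factor=2
        return t.repeat_interleave(repeats=factor, dim=0)
    return t[:expected_views]  # fallback: truncate (padding branch omitted here)

print(align_views(torch.randn(6, 3, 384, 384), 3).shape)  # torch.Size([3, 3, 384, 384])
```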
@@ -385,10 +223,6 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
                 final_views = t.shape[0]
                 # Total token count after aligning to the patch sequence
                 img.token_num = final_views * patch_len
-                print(
-                    f"[debug] mineru2_visual actual_k={actual_k} expected_views={expected_views}"
-                    f" final_views={final_views} final_token_num={img.token_num} uuid={img.uuid}"
-                )
             else:
                 raise Exception("Unsupported input types: {} for {}".format(type(img), img))
 
@@ -398,10 +232,6 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
             else:
                 cur_num = patch_len
             valid_ids.append([valid_id, valid_id + cur_num])
-            print(
-                f"[debug] mineru2_visual valid_ids_append uuid={img.uuid}"
-                f" range=({valid_id},{valid_id + cur_num}) cur_num={cur_num}"
-            )
             valid_id += cur_num
 
         if len(img_tensors) <= 0:
@@ -410,9 +240,5 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
         img = torch.cat(img_tensors, dim=0)
         img = img.cuda()
         all_img_embeds = self.forward(img)
-        print(
-            f"[debug] mineru2_visual all_img_embeds.shape={tuple(all_img_embeds.shape)}"
-            f" total_tokens={img.shape[0] * patch_len}"
-        )
 
         return all_img_embeds, uuids, valid_ids
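For reference, `valid_ids` records a half-open `[start, end)` row range per image into the stacked embedding tensor, with `end - start = views * patch_len`. A tiny worked example; the token counts are made up:

```python
# Worked example of the valid_ids bookkeeping in encode().
patch_len = 27 * 27                              # 729 tokens per view
token_counts = [1 * patch_len, 3 * patch_len]    # one 1-view and one 3-view image

valid_ids, valid_id = [], 0
for cur_num in token_counts:
    valid_ids.append([valid_id, valid_id + cur_num])
    valid_id += cur_num
print(valid_ids)  # [[0, 729], [729, 2916]]
```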