@@ -27,21 +27,22 @@
 logger = init_logger(__name__)
 
 
-def build_vision_tower(config: Mineru2QwenConfig):
+def build_vision_tower(weight_dir: str, config: Mineru2QwenConfig):
     vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
-    model_path = getattr(config, "_name_or_path", "")
+    model_path = os.path.join(weight_dir, vision_tower)
 
-    def _resolve_path(name):
-        if model_path:
-            return f"{model_path}/{name}"
-        return name
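+    # Prefer a local copy of the vision tower under weight_dir; fall back to the
+    # identifier from the config when nothing exists at that path.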
+    def _resolve_path():
+        if os.path.exists(model_path):
+            return model_path
+        else:
+            return vision_tower
 
     if "clip" in vision_tower.lower():
-        vt_path = _resolve_path(vision_tower)
+        vt_path = _resolve_path()
         print(f"[debug] load clip from {vt_path}")
         return CLIPVisionModel.from_pretrained(vt_path)
     elif "siglip" in vision_tower.lower():
-        vt_path = _resolve_path(vision_tower)
+        vt_path = _resolve_path()
         print(f"[debug] load siglip from {vt_path}")
         # Plan A: trim layers via the config, instantiate the model from that config, then load weights (ignoring mismatched sizes)
         cfg = SiglipVisionConfig.from_pretrained(vt_path)
@@ -86,71 +87,60 @@ def __init__(self):
         pass
 
     def _load_projector_weights(self, weight_dir: str):
-        # Scan safetensors/bin files and try to load the projector weights
-        def iter_state_dicts(dir_path: str):
-            for f in os.listdir(dir_path):
-                full = os.path.join(dir_path, f)
-                if not os.path.isfile(full):
-                    continue
-                if f.endswith(".safetensors"):
-                    try:
-                        with safe_open(full, framework="pt", device="cpu") as sf:
-                            yield {k: sf.get_tensor(k) for k in sf.keys()}
-                    except Exception as e:
-                        print(f"[warning] safetensors read fail: {full} due to {e}")
-                elif f.endswith(".bin"):
-                    try:
-                        state = torch.load(full, map_location="cpu")
-                        if isinstance(state, dict):
-                            yield state
-                    except Exception as e:
-                        print(f"[warning] bin read fail: {full} due to {e}")
+        projector_weight_path = os.path.join(weight_dir, "model.safetensors")
+        print(f"[debug] load projector weights from {projector_weight_path}")
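+        # Assumes the exported checkpoint keeps the projector (and optionally the
+        # vision tower) tensors in a single model.safetensors under weight_dir.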
 
         def assign_linear(linear: nn.Linear, w: torch.Tensor = None, b: torch.Tensor = None):
             if w is not None:
                 linear.weight.data.copy_(w.to(dtype=linear.weight.dtype))
             if b is not None and linear.bias is not None:
                 linear.bias.data.copy_(b.to(dtype=linear.bias.dtype))
 
-        def try_assign_from_keydict(key_to_tensor: dict) -> bool:
-            # Compatible namings:
-            # - single Linear: model.mm_projector.(weight|bias) / model.mm_projector.linear.(weight|bias)
-            # - 2-layer MLP: model.mm_projector.{0,2}.(weight|bias)
-            # - LLaVA-style aliases: multi_modal_projector.linear_1 / linear_2
+        # Collect the projector's Linear modules (list order is the write order)
+        if isinstance(self.projector, nn.Linear):
+            print(f"[debug] projector type: {type(self.projector)}")
+            linear_modules = [self.projector]
+        elif isinstance(self.projector, nn.Sequential):
+            print(f"[debug] projector type: {type(self.projector)}")
+            linear_modules = [m for m in self.projector if isinstance(m, nn.Linear)]
+        else:
+            print(f"[debug] projector type: {type(self.projector)}")
+            raise RuntimeError(f"Unsupported projector type: {type(self.projector)}")
+
+        def assign_projector_from_state(sd: dict) -> bool:
+            # Single Linear: prefer matching the whole weight directly; otherwise fall back to the first layer
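+            # Key layouts handled: model.mm_projector.(weight|bias),
+            # model.mm_projector.linear.*, model.mm_projector.{0,2}.*,
+            # and LLaVA-style multi_modal_projector.linear_{1,2}.*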
             if len(linear_modules) == 1:
-                w = None
-                b = None
-                for k in ("model.mm_projector.weight", "model.mm_projector.linear.weight"):
-                    if k in key_to_tensor:
-                        w = key_to_tensor[k]
-                        break
-                for k in ("model.mm_projector.bias", "model.mm_projector.linear.bias"):
-                    if k in key_to_tensor:
-                        b = key_to_tensor[k]
-                        break
+                print("[debug] projector load: single Linear matched (model.mm_projector.*)")
+                w = next(
+                    (sd[k] for k in ("model.mm_projector.weight", "model.mm_projector.linear.weight") if k in sd), None
+                )
+                b = next(
+                    (sd[k] for k in ("model.mm_projector.bias", "model.mm_projector.linear.bias") if k in sd), None
+                )
                 if w is not None:
                     assign_linear(linear_modules[0], w, b)
-                    print("[debug] projector load: single Linear matched")
                     return True
-                # Fallback: if the weights are stored per-layer but there is only one local layer, try the first layer
-                for k in ("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight"):
-                    if k in key_to_tensor:
-                        w = key_to_tensor[k]
-                        break
-                for k in ("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias"):
-                    if k in key_to_tensor:
-                        b = key_to_tensor[k]
-                        break
+                # Fallback: if per-layer weights exist, take only the first layer
+                w = next(
+                    (
+                        sd[k]
+                        for k in ("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight")
+                        if k in sd
+                    ),
+                    None,
+                )
+                b = next(
+                    (sd[k] for k in ("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias") if k in sd),
+                    None,
+                )
                 if w is not None:
                     assign_linear(linear_modules[0], w, b)
                     print("[debug] projector load: fallback to first layer for single Linear")
                     return True
                 return False
 
             # Multi-layer (e.g. mlp2x_gelu): match layer by layer against common namings
-            assigned = 0
             layer_key_map = [
-                # (idx, weight_keys, bias_keys)
                 (
                     0,
                     ("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight"),
@@ -162,11 +152,12 @@ def try_assign_from_keydict(key_to_tensor: dict) -> bool:
162152 ("model.mm_projector.2.bias" , "multi_modal_projector.linear_2.bias" ),
163153 ),
164154 ]
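+            # Write each matched (weight, bias) pair into the local layer at idx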
+            assigned = 0
             for idx, w_keys, b_keys in layer_key_map:
                 if idx >= len(linear_modules):
                     continue
-                w = next((key_to_tensor[k] for k in w_keys if k in key_to_tensor), None)
-                b = next((key_to_tensor[k] for k in b_keys if k in key_to_tensor), None)
+                w = next((sd[k] for k in w_keys if k in sd), None)
+                b = next((sd[k] for k in b_keys if k in sd), None)
                 if w is not None:
                     assign_linear(linear_modules[idx], w, b)
                     assigned += 1
@@ -175,33 +166,75 @@ def try_assign_from_keydict(key_to_tensor: dict) -> bool:
                 return True
             return False
 
-        # Collect local Linear modules (list order is the write order)
-        if isinstance(self.projector, nn.Linear):
-            linear_modules = [self.projector]
-        elif isinstance(self.projector, nn.Sequential):
-            linear_modules = [m for m in self.projector if isinstance(m, nn.Linear)]
+        def try_load_vision_tower(sd: dict):
+            # Following ref: strip the 'model.vision_tower.vision_tower.' prefix before loading (optional)
+            if not hasattr(self, "vision_tower") or not isinstance(
+                self.vision_tower, (CLIPVisionModel, SiglipVisionModel)
+            ):
+                return False, 0
+            vt_prefix = "model.vision_tower.vision_tower."
+            vt_sd = {k[len(vt_prefix):]: v for k, v in sd.items() if k.startswith(vt_prefix)}
+            if not vt_sd:
+                return False, 0
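+            # strict=False tolerates keys missing after layer trimming (Plan A above)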
+            try:
+                missing, unexpected = self.vision_tower.load_state_dict(vt_sd, strict=False)
+                num = len(vt_sd)
+                print(
+                    f"[debug] vision_tower load: keys={num}"
+                    f" missing={len(missing) if isinstance(missing, (list, tuple)) else 'n/a'}"
+                    f" unexpected={len(unexpected) if isinstance(unexpected, (list, tuple)) else 'n/a'}"
+                )
+                return True, num
+            except Exception as e:
+                print(f"[warning] vision_tower load_state_dict failed (strict=False): {e}")
+                return False, 0
+
+        # Load only from the specified file (prefer .safetensors, fall back to the same-name .bin)
+        if os.path.isfile(projector_weight_path) and projector_weight_path.endswith(".safetensors"):
+            try:
+                with safe_open(projector_weight_path, framework="pt", device="cpu") as sf:
+                    sd = {k: sf.get_tensor(k) for k in sf.keys()}
+            except Exception as e:
+                raise RuntimeError(f"Failed to read projector weights: {projector_weight_path} due to {e}")
         else:
-            raise RuntimeError(f"Unsupported projector type: {type(self.projector)}")
-
-        found = False
-        for sd in iter_state_dicts(weight_dir):
-            if try_assign_from_keydict(sd):
-                found = True
-                break
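+            # Derive the sibling .bin path by swapping the .safetensors suffix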
+            bin_path = (
+                projector_weight_path[: -len(".safetensors")] + ".bin"
+                if projector_weight_path.endswith(".safetensors")
+                else projector_weight_path
+            )
+            if os.path.isfile(bin_path):
+                try:
+                    sd = torch.load(bin_path, map_location="cpu")
+                    if not isinstance(sd, dict):
+                        raise RuntimeError("Loaded non-dict state from bin file")
+                    print(f"[debug] fallback load projector weights from {bin_path}")
+                except Exception as e:
+                    raise RuntimeError(f"Failed to read projector weights: {bin_path} due to {e}")
+            else:
+                raise RuntimeError(f"Projector weight file not found: {projector_weight_path}")
 
-        if not found:
+        # Load the projector (required)
+        projector_loaded = assign_projector_from_state(sd)
+        if not projector_loaded:
             raise RuntimeError(
                 "Projector weights not found in checkpoint. "
                 "Expected keys like 'model.mm_projector.{0,2}.(weight|bias)' or "
                 "'multi_modal_projector.linear_{1,2}.(weight|bias)' "
                 "or 'model.mm_projector.(weight|bias)'."
             )
 
+        # Optionally load the vision_tower
+        vision_loaded, vision_loaded_keys = try_load_vision_tower(sd)
+        if vision_loaded:
+            print(f"[debug] vision_tower weights loaded from checkpoint: keys={vision_loaded_keys}")
+        else:
+            print("[debug] vision_tower weights not found in checkpoint or skipped; keep pretrained weights")
+
     def load_model(self, weight_dir):
         print(f"[debug] load vision model: {weight_dir}")
         vision_config = Mineru2QwenConfig.from_pretrained(weight_dir)
 
-        self.vision_tower = build_vision_tower(vision_config)
+        self.vision_tower = build_vision_tower(weight_dir, vision_config)
         self.vision_tower.eval()
         self.vision_tower.requires_grad_(False)
         self.projector = build_vision_projector(vision_config)