Skip to content

Commit 1990422

Browse files
committed
fix
1 parent 59ec94b commit 1990422

File tree

1 file changed

+49
-223
lines changed

1 file changed

+49
-223
lines changed

lightllm/models/mineru2_qwen/mineru2_visual.py

Lines changed: 49 additions & 223 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import numpy as np
1010
from transformers import (
1111
CLIPVisionModel,
12-
CLIPVisionConfig,
1312
SiglipVisionConfig,
1413
SiglipVisionModel,
1514
)
@@ -30,32 +29,16 @@
3029
def build_vision_tower(weight_dir: str, config: Mineru2QwenConfig):
3130
vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
3231
model_path = os.path.join(weight_dir, vision_tower)
33-
34-
def _resolve_path():
35-
if os.path.exists(model_path):
36-
return model_path
37-
else:
38-
return vision_tower
32+
if not os.path.exists(model_path):
33+
model_path = vision_tower
3934

4035
if "clip" in vision_tower.lower():
41-
vt_path = _resolve_path()
42-
print(f"[debug] load clip from {vt_path}")
43-
return CLIPVisionModel.from_pretrained(vt_path)
36+
return CLIPVisionModel.from_pretrained(model_path)
4437
elif "siglip" in vision_tower.lower():
45-
vt_path = _resolve_path()
46-
print(f"[debug] load siglip from {vt_path}")
47-
# 方案A:使用配置减层并按该配置实例化模型,再加载权重(忽略不匹配尺寸)
48-
cfg = SiglipVisionConfig.from_pretrained(vt_path)
49-
old_layers = getattr(cfg, "num_hidden_layers", None)
38+
cfg = SiglipVisionConfig.from_pretrained(model_path)
5039
cfg.num_hidden_layers = max(0, cfg.num_hidden_layers - 1)
5140
cfg.vision_use_head = False
52-
model = SiglipVisionModel.from_pretrained(vt_path, config=cfg, ignore_mismatched_sizes=True)
53-
try:
54-
actual_layers = len(model.vision_model.encoder.layers) # type: ignore[attr-defined]
55-
except Exception:
56-
actual_layers = None
57-
new_cfg_layers = getattr(getattr(model, "config", None), "num_hidden_layers", None)
58-
print(f"[debug] siglip_layers planA old={old_layers} new_cfg={new_cfg_layers} actual_module={actual_layers}")
41+
model = SiglipVisionModel.from_pretrained(model_path, config=cfg, ignore_mismatched_sizes=True)
5942
return model
6043
else:
6144
raise ValueError(f"Unknown vision tower: {vision_tower}")
@@ -87,151 +70,61 @@ def __init__(self):
8770
pass
8871

8972
def _load_projector_weights(self, weight_dir: str):
90-
projector_weight_path = os.path.join(weight_dir, "model.safetensors")
91-
print(f"[debug] load projector weights from {projector_weight_path}")
92-
9373
def assign_linear(linear: nn.Linear, w: torch.Tensor = None, b: torch.Tensor = None):
9474
if w is not None:
9575
linear.weight.data.copy_(w.to(dtype=linear.weight.dtype))
9676
if b is not None and linear.bias is not None:
9777
linear.bias.data.copy_(b.to(dtype=linear.bias.dtype))
9878

99-
# 收集 projector Linear 模块(顺序即写入顺序)
79+
projector_weight_path = os.path.join(weight_dir, "model.safetensors")
80+
10081
if isinstance(self.projector, nn.Linear):
101-
print(f"[debug] projector type: {type(self.projector)}")
10282
linear_modules = [self.projector]
10383
elif isinstance(self.projector, nn.Sequential):
104-
print(f"[debug] projector type: {type(self.projector)}")
10584
linear_modules = [m for m in self.projector if isinstance(m, nn.Linear)]
10685
else:
107-
print(f"[debug] projector type: {type(self.projector)}")
10886
raise RuntimeError(f"Unsupported projector type: {type(self.projector)}")
10987

110-
def assign_projector_from_state(sd: dict) -> bool:
111-
# 单层线性:优先直接匹配整体权重;否则回退到首层
112-
if len(linear_modules) == 1:
113-
print("[debug] projector load: single Linear matched (model.mm_projector.*)")
114-
w = next(
115-
(sd[k] for k in ("model.mm_projector.weight", "model.mm_projector.linear.weight") if k in sd), None
116-
)
117-
b = next(
118-
(sd[k] for k in ("model.mm_projector.bias", "model.mm_projector.linear.bias") if k in sd), None
119-
)
120-
if w is not None:
121-
assign_linear(linear_modules[0], w, b)
122-
return True
123-
# 兜底:若分层存在,仅取第一层
124-
w = next(
125-
(
126-
sd[k]
127-
for k in ("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight")
128-
if k in sd
129-
),
130-
None,
131-
)
132-
b = next(
133-
(sd[k] for k in ("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias") if k in sd),
134-
None,
135-
)
136-
if w is not None:
137-
assign_linear(linear_modules[0], w, b)
138-
print("[debug] projector load: fallback to first layer for single Linear")
139-
return True
140-
return False
141-
142-
# 多层(如 mlp2x_gelu):按常见命名逐一匹配
143-
layer_key_map = [
144-
(
145-
0,
146-
("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight"),
147-
("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias"),
148-
),
149-
(
150-
1,
151-
("model.mm_projector.2.weight", "multi_modal_projector.linear_2.weight"),
152-
("model.mm_projector.2.bias", "multi_modal_projector.linear_2.bias"),
153-
),
154-
]
155-
assigned = 0
156-
for idx, w_keys, b_keys in layer_key_map:
157-
if idx >= len(linear_modules):
158-
continue
159-
w = next((sd[k] for k in w_keys if k in sd), None)
160-
b = next((sd[k] for k in b_keys if k in sd), None)
161-
if w is not None:
162-
assign_linear(linear_modules[idx], w, b)
163-
assigned += 1
164-
if assigned > 0:
165-
print(f"[debug] projector load: assigned {assigned} Linear layers")
166-
return True
167-
return False
168-
169-
def try_load_vision_tower(sd: dict):
170-
# 参考 ref: 去掉前缀 'model.vision_tower.vision_tower.' 进行加载(可选)
171-
if not hasattr(self, "vision_tower") or not isinstance(
172-
self.vision_tower, (CLIPVisionModel, SiglipVisionModel)
173-
):
174-
return False, 0
175-
vt_prefix = "model.vision_tower.vision_tower."
176-
vt_sd = {k[len(vt_prefix) :]: v for k, v in sd.items() if k.startswith(vt_prefix)}
177-
if not vt_sd:
178-
return False, 0
179-
try:
180-
missing, unexpected = self.vision_tower.load_state_dict(vt_sd, strict=False)
181-
num = len(vt_sd)
182-
print(
183-
f"[debug] vision_tower load: keys={num}"
184-
f" missing={len(missing) if isinstance(missing, (list, tuple)) else 'n/a'}"
185-
f" unexpected={len(unexpected) if isinstance(unexpected, (list, tuple)) else 'n/a'}"
186-
)
187-
return True, num
188-
except Exception as e:
189-
print(f"[warning] vision_tower load_state_dict failed (strict=False): {e}")
190-
return False, 0
191-
192-
# 仅从指定文件加载(优先 .safetensors,fallback 到同名 .bin)
193-
if os.path.isfile(projector_weight_path) and projector_weight_path.endswith(".safetensors"):
194-
try:
195-
with safe_open(projector_weight_path, framework="pt", device="cpu") as sf:
196-
sd = {k: sf.get_tensor(k) for k in sf.keys()}
197-
except Exception as e:
198-
raise RuntimeError(f"Failed to read projector weights: {projector_weight_path} due to {e}")
199-
else:
200-
bin_path = (
201-
projector_weight_path[:-14] + ".bin"
202-
if projector_weight_path.endswith(".safetensors")
203-
else projector_weight_path
204-
)
205-
if os.path.isfile(bin_path):
206-
try:
207-
sd = torch.load(bin_path, map_location="cpu")
208-
if not isinstance(sd, dict):
209-
raise RuntimeError("Loaded non-dict state from bin file")
210-
print(f"[debug] fallback load projector weights from {bin_path}")
211-
except Exception as e:
212-
raise RuntimeError(f"Failed to read projector weights: {bin_path} due to {e}")
213-
else:
214-
raise RuntimeError(f"Projector weight file not found: {projector_weight_path}")
215-
216-
# 加载 projector(必要)
217-
projector_loaded = assign_projector_from_state(sd)
218-
if not projector_loaded:
219-
raise RuntimeError(
220-
"Projector weights not found in checkpoint. "
221-
"Expected keys like 'model.mm_projector.{0,2}.(weight|bias)' or "
222-
"'multi_modal_projector.linear_{1,2}.(weight|bias)' "
223-
"or 'model.mm_projector.(weight|bias)'."
224-
)
225-
226-
# 可选加载 vision_tower
227-
vision_loaded, vision_loaded_keys = try_load_vision_tower(sd)
228-
if vision_loaded:
229-
print(f"[debug] vision_tower weights loaded from checkpoint: keys={vision_loaded_keys}")
230-
else:
231-
print("[debug] vision_tower weights not found in checkpoint or skipped; keep pretrained weights")
88+
try:
89+
with safe_open(projector_weight_path, framework="pt", device="cpu") as sf:
90+
sd = {k: sf.get_tensor(k) for k in sf.keys()}
91+
except Exception as e:
92+
raise RuntimeError(f"Failed to read projector weights: {projector_weight_path} due to {e}")
93+
94+
# load projector weights
95+
layer_key_map = [
96+
(
97+
0,
98+
("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight"),
99+
("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias"),
100+
),
101+
(
102+
1,
103+
("model.mm_projector.2.weight", "multi_modal_projector.linear_2.weight"),
104+
("model.mm_projector.2.bias", "multi_modal_projector.linear_2.bias"),
105+
),
106+
]
107+
for idx, w_keys, b_keys in layer_key_map:
108+
if idx >= len(linear_modules):
109+
continue
110+
w = next((sd[k] for k in w_keys if k in sd), None)
111+
b = next((sd[k] for k in b_keys if k in sd), None)
112+
if w is not None:
113+
assign_linear(linear_modules[idx], w, b)
114+
115+
# load vision tower weights
116+
vt_prefix = "model.vision_tower.vision_tower."
117+
vt_sd = {k[len(vt_prefix) :]: v for k, v in sd.items() if k.startswith(vt_prefix)}
118+
if not vt_sd:
119+
logger.warning("vision_tower weights not found in checkpoint or skipped; keep pretrained weights")
120+
return
121+
122+
try:
123+
self.vision_tower.load_state_dict(vt_sd, strict=False)
124+
except Exception as e:
125+
logger.warning(f"vision_tower load_state_dict failed (strict=False): {e}")
232126

233127
def load_model(self, weight_dir):
234-
print(f"[debug] load vision model: {weight_dir}")
235128
vision_config = Mineru2QwenConfig.from_pretrained(weight_dir)
236129

237130
self.vision_tower = build_vision_tower(weight_dir, vision_config)
@@ -251,50 +144,26 @@ def cuda(self):
251144
return self
252145

253146
def forward(self, x) -> torch.Tensor:
254-
# 运行时形状与精度/设备检查
255-
try:
256-
print(f"[debug] mineru2_visual.forward x.shape={tuple(x.shape)} dtype={x.dtype} device={x.device}")
257-
except Exception:
258-
pass
259147
vision_out = self.vision_tower(x, output_hidden_states=True)
260148
hiddens = vision_out.hidden_states
261-
# hidden_states 数量与 config 层数的关系(一般为 num_layers + 1)
262-
try:
263-
cfg_layers = getattr(getattr(self.vision_tower, "config", None), "num_hidden_layers", None)
264-
eff_layers = len(hiddens) - 1 if isinstance(hiddens, (list, tuple)) else None
265-
print(
266-
f"[debug] mineru2_visual.hidden_states len={len(hiddens)}"
267-
f" cfg_layers={cfg_layers} eff_layers={eff_layers}"
268-
)
269-
except Exception:
270-
pass
149+
271150
# 对齐ref的“减一层”语义:优先使用倒数第二层;若不可用则回退最后一层
272151
try:
273152
chosen_idx = -2 if isinstance(hiddens, (list, tuple)) and len(hiddens) >= 2 else -1
274153
feat = hiddens[chosen_idx]
275-
print(f"[debug] mineru2_visual.select_layer idx={chosen_idx} feat.shape={tuple(feat.shape)}")
276154
except Exception:
277155
feat = hiddens[-2] if isinstance(hiddens, (list, tuple)) and len(hiddens) >= 2 else hiddens[-1]
278156
# 切回 patch 序列特征:去除 CLS(若存在),按序列过 projector,再展平为 (views*patch, hidden)
279157
patch_side = self.vision_tower.config.image_size // self.vision_tower.config.patch_size
280158
patch_len = patch_side * patch_side
281159
if feat.shape[1] == patch_len + 1:
282160
feat = feat[:, 1:, :]
283-
print(f"[debug] mineru2_visual.drop_cls patch_len={patch_len} feat_no_cls.shape={tuple(feat.shape)}")
284161
proj_seq = self.projector(feat)
285-
try:
286-
print(f"[debug] mineru2_visual.projector_seq_out shape={tuple(proj_seq.shape)} (views, patch, hidden)")
287-
except Exception:
288-
pass
162+
289163
proj = proj_seq.reshape(-1, proj_seq.shape[-1])
290-
try:
291-
print(f"[debug] mineru2_visual.projector_flat_out shape={tuple(proj.shape)} (views*patch, hidden)")
292-
except Exception:
293-
pass
294164
return proj
295165

296166
def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List[List[int]]]:
297-
print(f"[debug] mineru2_visual encode images {len(images)}")
298167
img_tensors: List[torch.Tensor] = []
299168
uuids: List[str] = []
300169
valid_id = 0
@@ -304,7 +173,7 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
304173
# 每视图 patch_len(例如 384/14=27, 27^2=729)
305174
patch_side = self.vision_tower.config.image_size // self.vision_tower.config.patch_size
306175
patch_len = patch_side * patch_side
307-
print(f"[debug] mineru2_visual.patch_len={patch_len} (side={patch_side})")
176+
308177
for i, img in enumerate(images):
309178
if isinstance(img, ImageItem):
310179
uuids.append(img.uuid)
@@ -321,59 +190,28 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
321190
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
322191

323192
if t.ndim == 5:
324-
print(f"[debug] mineru2_visual reshape t.ndim: {t.ndim}, t.shape: {t.shape}")
325193
t = t.view(-1, t.shape[-3], t.shape[-2], t.shape[-1])
326194
elif t.ndim == 3:
327-
print(f"[debug] mineru2_visual unsqueeze t.ndim: {t.ndim}, t.shape: {t.shape}")
328195
t = t.unsqueeze(0)
329-
# 在修改前记录 manager 分配的 token_num(可能是视图数或视图*patch数)
330-
try:
331-
print(f"[debug] mineru2_visual manager_token_num_before={img.token_num} uuid={img.uuid}")
332-
except Exception:
333-
pass
196+
334197
# 对齐实际视图数 K 与期望 token(可能是 K 或 K*patch_len)
335198
expected_token = img.token_num if getattr(img, "token_num", None) is not None else None
336199
actual_k = t.shape[0]
337200
if expected_token is None or expected_token <= 0:
338201
expected_views = actual_k
339-
print(
340-
f"[debug] mineru2_visual expected_views_from_actual uuid={img.uuid}"
341-
f" expected_views={expected_views}"
342-
)
343202
else:
344203
if expected_token >= patch_len and expected_token % patch_len == 0:
345204
expected_views = expected_token // patch_len
346-
print(
347-
f"[debug] mineru2_visual expected_views_from_tokens uuid={img.uuid}"
348-
f" expected_token={expected_token} patch_len={patch_len} expected_views={expected_views}"
349-
)
350205
else:
351206
expected_views = expected_token
352-
print(
353-
f"[debug] mineru2_visual expected_views_interpret_as_views uuid={img.uuid}"
354-
f" expected_views={expected_views}"
355-
)
356207
if actual_k != expected_views:
357208
if actual_k % expected_views == 0:
358209
factor = actual_k // expected_views
359-
print(
360-
f"[debug] mineru2_visual down_aggregate uuid={img.uuid}"
361-
f" actual_k={actual_k} expected_views={expected_views} factor={factor}"
362-
)
363210
t = t.view(expected_views, factor, t.shape[1], t.shape[2], t.shape[3]).mean(dim=1)
364211
elif expected_views % actual_k == 0:
365212
factor = expected_views // actual_k
366-
print(
367-
f"[debug] mineru2_visual up_repeat uuid={img.uuid}"
368-
f" actual_k={actual_k} expected_views={expected_views} factor={factor}"
369-
)
370213
t = t.repeat_interleave(repeats=factor, dim=0)
371214
else:
372-
k = min(actual_k, expected_views)
373-
print(
374-
f"[debug] mineru2_visual fallback_slice uuid={img.uuid}"
375-
f" actual_k={actual_k} expected_views={expected_views} k={k}"
376-
)
377215
if actual_k >= expected_views:
378216
t = t[:expected_views]
379217
else:
@@ -385,10 +223,6 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
385223
final_views = t.shape[0]
386224
# 对齐 patch 序列后的总 token 数
387225
img.token_num = final_views * patch_len
388-
print(
389-
f"[debug] mineru2_visual actual_k={actual_k} expected_views={expected_views}"
390-
f" final_views={final_views} final_token_num={img.token_num} uuid={img.uuid}"
391-
)
392226
else:
393227
raise Exception("Unsupport input types: {} for {}".format(type(img), img))
394228

@@ -398,10 +232,6 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
398232
else:
399233
cur_num = patch_len
400234
valid_ids.append([valid_id, valid_id + cur_num])
401-
print(
402-
f"[debug] mineru2_visual valid_ids_append uuid={img.uuid}"
403-
f" range=({valid_id},{valid_id + cur_num}) cur_num={cur_num}"
404-
)
405235
valid_id += cur_num
406236

407237
if len(img_tensors) <= 0:
@@ -410,9 +240,5 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
410240
img = torch.cat(img_tensors, dim=0)
411241
img = img.cuda()
412242
all_img_embeds = self.forward(img)
413-
print(
414-
f"[debug] mineru2_visual all_img_embeds.shape={tuple(all_img_embeds.shape)}"
415-
f" total_tokens={img.shape[0] * patch_len}"
416-
)
417243

418244
return all_img_embeds, uuids, valid_ids

0 commit comments

Comments
 (0)