Skip to content

Commit 797fc55

Browse files
committed
feat: can recognize
1 parent 93625f0 commit 797fc55

File tree

2 files changed

+104
-71
lines changed

2 files changed

+104
-71
lines changed

lightllm/models/mineru2_qwen/mineru2_visual.py

Lines changed: 103 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,22 @@
2727
logger = init_logger(__name__)
2828

2929

30-
def build_vision_tower(config: Mineru2QwenConfig):
30+
def build_vision_tower(weight_dir: str, config: Mineru2QwenConfig):
3131
vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
32-
model_path = getattr(config, "_name_or_path", "")
32+
model_path = os.path.join(weight_dir, vision_tower)
3333

34-
def _resolve_path(name):
35-
if model_path:
36-
return f"{model_path}/{name}"
37-
return name
34+
def _resolve_path():
35+
if os.path.exists(model_path):
36+
return model_path
37+
else:
38+
return vision_tower
3839

3940
if "clip" in vision_tower.lower():
40-
vt_path = _resolve_path(vision_tower)
41+
vt_path = _resolve_path()
4142
print(f"[debug] load clip from {vt_path}")
4243
return CLIPVisionModel.from_pretrained(vt_path)
4344
elif "siglip" in vision_tower.lower():
44-
vt_path = _resolve_path(vision_tower)
45+
vt_path = _resolve_path()
4546
print(f"[debug] load siglip from {vt_path}")
4647
# 方案A:使用配置减层并按该配置实例化模型,再加载权重(忽略不匹配尺寸)
4748
cfg = SiglipVisionConfig.from_pretrained(vt_path)
@@ -86,71 +87,60 @@ def __init__(self):
8687
pass
8788

8889
def _load_projector_weights(self, weight_dir: str):
89-
# 扫描 safetensors/bin 文件并尝试加载 projector 权重
90-
def iter_state_dicts(dir_path: str):
91-
for f in os.listdir(dir_path):
92-
full = os.path.join(dir_path, f)
93-
if not os.path.isfile(full):
94-
continue
95-
if f.endswith(".safetensors"):
96-
try:
97-
with safe_open(full, framework="pt", device="cpu") as sf:
98-
yield {k: sf.get_tensor(k) for k in sf.keys()}
99-
except Exception as e:
100-
print(f"[warning] safetensors read fail: {full} due to {e}")
101-
elif f.endswith(".bin"):
102-
try:
103-
state = torch.load(full, map_location="cpu")
104-
if isinstance(state, dict):
105-
yield state
106-
except Exception as e:
107-
print(f"[warning] bin read fail: {full} due to {e}")
90+
projector_weight_path = os.path.join(weight_dir, "model.safetensors")
91+
print(f"[debug] load projector weights from {projector_weight_path}")
10892

10993
def assign_linear(linear: nn.Linear, w: torch.Tensor = None, b: torch.Tensor = None):
11094
if w is not None:
11195
linear.weight.data.copy_(w.to(dtype=linear.weight.dtype))
11296
if b is not None and linear.bias is not None:
11397
linear.bias.data.copy_(b.to(dtype=linear.bias.dtype))
11498

115-
def try_assign_from_keydict(key_to_tensor: dict) -> bool:
116-
# 兼容命名:
117-
# - 线性:model.mm_projector.(weight|bias) / model.mm_projector.linear.(weight|bias)
118-
# - 2层MLP:model.mm_projector.{0,2}.(weight|bias)
119-
# - LLaVA风格别名:multi_modal_projector.linear_1 / linear_2
99+
# 收集 projector Linear 模块(顺序即写入顺序)
100+
if isinstance(self.projector, nn.Linear):
101+
print(f"[debug] projector type: {type(self.projector)}")
102+
linear_modules = [self.projector]
103+
elif isinstance(self.projector, nn.Sequential):
104+
print(f"[debug] projector type: {type(self.projector)}")
105+
linear_modules = [m for m in self.projector if isinstance(m, nn.Linear)]
106+
else:
107+
print(f"[debug] projector type: {type(self.projector)}")
108+
raise RuntimeError(f"Unsupported projector type: {type(self.projector)}")
109+
110+
def assign_projector_from_state(sd: dict) -> bool:
111+
# 单层线性:优先直接匹配整体权重;否则回退到首层
120112
if len(linear_modules) == 1:
121-
w = None
122-
b = None
123-
for k in ("model.mm_projector.weight", "model.mm_projector.linear.weight"):
124-
if k in key_to_tensor:
125-
w = key_to_tensor[k]
126-
break
127-
for k in ("model.mm_projector.bias", "model.mm_projector.linear.bias"):
128-
if k in key_to_tensor:
129-
b = key_to_tensor[k]
130-
break
113+
print("[debug] projector load: single Linear matched (model.mm_projector.*)")
114+
w = next(
115+
(sd[k] for k in ("model.mm_projector.weight", "model.mm_projector.linear.weight") if k in sd), None
116+
)
117+
b = next(
118+
(sd[k] for k in ("model.mm_projector.bias", "model.mm_projector.linear.bias") if k in sd), None
119+
)
131120
if w is not None:
132121
assign_linear(linear_modules[0], w, b)
133-
print("[debug] projector load: single Linear matched")
134122
return True
135-
# 兜底:若权重以分层形式存在,且本地只有一层,则尝试用第一层
136-
for k in ("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight"):
137-
if k in key_to_tensor:
138-
w = key_to_tensor[k]
139-
break
140-
for k in ("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias"):
141-
if k in key_to_tensor:
142-
b = key_to_tensor[k]
143-
break
123+
# 兜底:若分层存在,仅取第一层
124+
w = next(
125+
(
126+
sd[k]
127+
for k in ("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight")
128+
if k in sd
129+
),
130+
None,
131+
)
132+
b = next(
133+
(sd[k] for k in ("model.mm_projector.0.bias", "multi_modal_projector.linear_1.bias") if k in sd),
134+
None,
135+
)
144136
if w is not None:
145137
assign_linear(linear_modules[0], w, b)
146138
print("[debug] projector load: fallback to first layer for single Linear")
147139
return True
148140
return False
149141

150142
# 多层(如 mlp2x_gelu):按常见命名逐一匹配
151-
assigned = 0
152143
layer_key_map = [
153-
# (idx, weight_keys, bias_keys)
154144
(
155145
0,
156146
("model.mm_projector.0.weight", "multi_modal_projector.linear_1.weight"),
@@ -162,11 +152,12 @@ def try_assign_from_keydict(key_to_tensor: dict) -> bool:
162152
("model.mm_projector.2.bias", "multi_modal_projector.linear_2.bias"),
163153
),
164154
]
155+
assigned = 0
165156
for idx, w_keys, b_keys in layer_key_map:
166157
if idx >= len(linear_modules):
167158
continue
168-
w = next((key_to_tensor[k] for k in w_keys if k in key_to_tensor), None)
169-
b = next((key_to_tensor[k] for k in b_keys if k in key_to_tensor), None)
159+
w = next((sd[k] for k in w_keys if k in sd), None)
160+
b = next((sd[k] for k in b_keys if k in sd), None)
170161
if w is not None:
171162
assign_linear(linear_modules[idx], w, b)
172163
assigned += 1
@@ -175,33 +166,75 @@ def try_assign_from_keydict(key_to_tensor: dict) -> bool:
175166
return True
176167
return False
177168

178-
# 收集本地 Linear 模块(顺序即写入顺序)
179-
if isinstance(self.projector, nn.Linear):
180-
linear_modules = [self.projector]
181-
elif isinstance(self.projector, nn.Sequential):
182-
linear_modules = [m for m in self.projector if isinstance(m, nn.Linear)]
169+
def try_load_vision_tower(sd: dict):
170+
# 参考 ref: 去掉前缀 'model.vision_tower.vision_tower.' 进行加载(可选)
171+
if not hasattr(self, "vision_tower") or not isinstance(
172+
self.vision_tower, (CLIPVisionModel, SiglipVisionModel)
173+
):
174+
return False, 0
175+
vt_prefix = "model.vision_tower.vision_tower."
176+
vt_sd = {k[len(vt_prefix) :]: v for k, v in sd.items() if k.startswith(vt_prefix)}
177+
if not vt_sd:
178+
return False, 0
179+
try:
180+
missing, unexpected = self.vision_tower.load_state_dict(vt_sd, strict=False)
181+
num = len(vt_sd)
182+
print(
183+
f"[debug] vision_tower load: keys={num}"
184+
f" missing={len(missing) if isinstance(missing, (list, tuple)) else 'n/a'}"
185+
f" unexpected={len(unexpected) if isinstance(unexpected, (list, tuple)) else 'n/a'}"
186+
)
187+
return True, num
188+
except Exception as e:
189+
print(f"[warning] vision_tower load_state_dict failed (strict=False): {e}")
190+
return False, 0
191+
192+
# 仅从指定文件加载(优先 .safetensors,fallback 到同名 .bin)
193+
if os.path.isfile(projector_weight_path) and projector_weight_path.endswith(".safetensors"):
194+
try:
195+
with safe_open(projector_weight_path, framework="pt", device="cpu") as sf:
196+
sd = {k: sf.get_tensor(k) for k in sf.keys()}
197+
except Exception as e:
198+
raise RuntimeError(f"Failed to read projector weights: {projector_weight_path} due to {e}")
183199
else:
184-
raise RuntimeError(f"Unsupported projector type: {type(self.projector)}")
185-
186-
found = False
187-
for sd in iter_state_dicts(weight_dir):
188-
if try_assign_from_keydict(sd):
189-
found = True
190-
break
200+
bin_path = (
201+
projector_weight_path[:-14] + ".bin"
202+
if projector_weight_path.endswith(".safetensors")
203+
else projector_weight_path
204+
)
205+
if os.path.isfile(bin_path):
206+
try:
207+
sd = torch.load(bin_path, map_location="cpu")
208+
if not isinstance(sd, dict):
209+
raise RuntimeError("Loaded non-dict state from bin file")
210+
print(f"[debug] fallback load projector weights from {bin_path}")
211+
except Exception as e:
212+
raise RuntimeError(f"Failed to read projector weights: {bin_path} due to {e}")
213+
else:
214+
raise RuntimeError(f"Projector weight file not found: {projector_weight_path}")
191215

192-
if not found:
216+
# 加载 projector(必要)
217+
projector_loaded = assign_projector_from_state(sd)
218+
if not projector_loaded:
193219
raise RuntimeError(
194220
"Projector weights not found in checkpoint. "
195221
"Expected keys like 'model.mm_projector.{0,2}.(weight|bias)' or "
196222
"'multi_modal_projector.linear_{1,2}.(weight|bias)' "
197223
"or 'model.mm_projector.(weight|bias)'."
198224
)
199225

226+
# 可选加载 vision_tower
227+
vision_loaded, vision_loaded_keys = try_load_vision_tower(sd)
228+
if vision_loaded:
229+
print(f"[debug] vision_tower weights loaded from checkpoint: keys={vision_loaded_keys}")
230+
else:
231+
print("[debug] vision_tower weights not found in checkpoint or skipped; keep pretrained weights")
232+
200233
def load_model(self, weight_dir):
201234
print(f"[debug] load vision model: {weight_dir}")
202235
vision_config = Mineru2QwenConfig.from_pretrained(weight_dir)
203236

204-
self.vision_tower = build_vision_tower(vision_config)
237+
self.vision_tower = build_vision_tower(weight_dir, vision_config)
205238
self.vision_tower.eval()
206239
self.vision_tower.requires_grad_(False)
207240
self.projector = build_vision_projector(vision_config)

mm_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def run(query, uris):
1616
data = {
1717
"inputs": query,
1818
"parameters": {
19-
"max_new_tokens": 128,
19+
"max_new_tokens": 512,
2020
"ignore_eos": False,
2121
# The space before <|endoftext|> is important,
2222
# the server will remove the first bos_token_id,

0 commit comments

Comments
 (0)