Commit abe4402

feat: check token num
1 parent 3b05e31 commit abe4402

2 files changed (+127 additions, -32 deletions)

lightllm/models/mineru2_qwen/mineru2_visual.py

Lines changed: 92 additions & 22 deletions
@@ -6,6 +6,7 @@

 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import numpy as np
 from transformers import (
     CLIPVisionModel,
@@ -14,7 +15,12 @@
 )

 from .configuration_mineru2 import Mineru2QwenConfig
-from .image_processing_mineru2 import Mineru2ImageProcessor, expand2square, process_anyres_image
+from .image_processing_mineru2 import (
+    Mineru2ImageProcessor,
+    expand2square,
+    process_anyres_image,
+    get_anyres_image_grid_shape,
+)

 from lightllm.server.multimodal_params import ImageItem
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
@@ -179,7 +185,9 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
                 uuids.append(img.uuid)
                 image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data)).convert("RGB")
-                if image_aspect_ratio == "pad":
+                # Force pad for multi-image/video inputs; only a single image may use anyres
+                force_pad = len(images) > 1
+                if image_aspect_ratio == "pad" or force_pad:
                     image_proc = expand2square(image_data, tuple(int(x * 255) for x in self.image_processor.image_mean))
                     t = self.image_processor.preprocess(image_proc, return_tensors="pt")["pixel_values"]
                 elif image_aspect_ratio and (image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio):
@@ -194,16 +202,18 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
                 elif t.ndim == 3:
                     t = t.unsqueeze(0)

-                # Align the actual view count K with the expected tokens (may be K or K*patch_len)
-                expected_token = img.token_num if getattr(img, "token_num", None) is not None else None
+                # Align the actual view count K with the expected view count (anyres: Nx*Ny+1; otherwise: 1)
                 actual_k = t.shape[0]
-                if expected_token is None or expected_token <= 0:
-                    expected_views = actual_k
+                if (
+                    image_aspect_ratio and (image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio)
+                ) and not force_pad:
+                    crop_size = self.image_processor.crop_size["height"]
+                    grid_w, grid_h = get_anyres_image_grid_shape(
+                        (img.image_w, img.image_h), image_grid_pinpoints, crop_size
+                    )
+                    expected_views = int(grid_w * grid_h + 1)
                 else:
-                    if expected_token >= patch_len and expected_token % patch_len == 0:
-                        expected_views = expected_token // patch_len
-                    else:
-                        expected_views = expected_token
+                    expected_views = 1
                 if actual_k != expected_views:
                     if actual_k % expected_views == 0:
                         factor = actual_k // expected_views
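
Note: the expected view count above leans entirely on `get_anyres_image_grid_shape`. As a reading aid, here is a minimal sketch of the behavior this code appears to assume, modeled on the LLaVA-style helper (an assumption, not code from this repo): choose the grid pinpoint that best preserves the image, then express it in crop-size units.

```python
# Sketch only: assumes get_anyres_image_grid_shape follows the LLaVA-style
# best-fit selection; verify against image_processing_mineru2.py.
def sketch_anyres_grid_shape(image_size, grid_pinpoints, crop_size):
    img_w, img_h = image_size
    best, best_fit, best_waste = None, -1, float("inf")
    for res_w, res_h in grid_pinpoints:
        scale = min(res_w / img_w, res_h / img_h)
        # Effective resolution after scaling, capped at the original
        fit = min(int(img_w * scale) * int(img_h * scale), img_w * img_h)
        waste = res_w * res_h - fit
        if fit > best_fit or (fit == best_fit and waste < best_waste):
            best, best_fit, best_waste = (res_w, res_h), fit, waste
    return best[0] // crop_size, best[1] // crop_size  # (grid_w, grid_h)

# e.g. a 1000x700 image with pinpoints [(672, 672), (336, 1008)] and crop 336
# maps to (2, 2), so expected_views = 2 * 2 + 1 = 5.
```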
@@ -219,26 +229,86 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
                         pad = t[-1:].repeat(expected_views - actual_k, 1, 1, 1)
                         t = torch.cat([t, pad], dim=0)
                 img_tensors.append(t)
-                # Final view count K
-                final_views = t.shape[0]
-                # Total token count after aligning to the patch sequence
-                img.token_num = final_views * patch_len
             else:
                 raise Exception("Unsupport input types: {} for {}".format(type(img), img))

-            # Token count for this image (views * patch_len)
-            if isinstance(img_tensors[-1], torch.Tensor) and img_tensors[-1].dim() == 4:
-                cur_num = img_tensors[-1].shape[0] * patch_len
-            else:
-                cur_num = patch_len
-            valid_ids.append([valid_id, valid_id + cur_num])
-            valid_id += cur_num
+            # Do not accumulate valid_ids yet; fill them in after reassembly, based on the real lengths

         if len(img_tensors) <= 0:
             return None, [], []
         # Ensure all tensors are 4-D before concatenating
         img = torch.cat(img_tensors, dim=0)
         img = img.cuda()
+        # Extract the patch-sequence embeddings for all views: (views * patch_len, hidden)
         all_img_embeds = self.forward(img)

-        return all_img_embeds, uuids, valid_ids
+        # Reassemble each image's view embeddings via spatial + unpad (+ anyres_max), appending a newline column
+        new_embeds: List[torch.Tensor] = []
+        cur = 0
+        for i, img in enumerate(images):
+            # View count of this image
+            t = img_tensors[i]
+            K = t.shape[0]
+            # Slice out the patch-sequence embeddings of all views of this image
+            tokens_len = K * patch_len
+            cur_views_embeds = all_img_embeds[cur : cur + tokens_len]
+            cur += tokens_len
+
+            # Non-anyres, or forced pad (multi-image/video): use the flat sequence directly (K is usually 1)
+            force_pad = len(images) > 1
+            aspect = getattr(self.image_processor, "image_aspect_ratio", None)
+            if not aspect or ("anyres" not in str(aspect)) or force_pad or K <= 1:
+                seq = cur_views_embeds
+                new_embeds.append(seq)
+                # Record the interval
+                valid_ids.append([valid_id, valid_id + seq.shape[0]])
+                valid_id += seq.shape[0]
+                continue
+
+            # Single-image anyres path:
+            # split off the base view from the remaining views
+            base_feature = cur_views_embeds[:patch_len]
+            rest = cur_views_embeds[patch_len:]
+            # (K-1, patch_len, hidden)
+            hidden = rest.shape[-1]
+            rest = rest.view(K - 1, patch_len, hidden)
+
+            # Compute Nx, Ny
+            crop_size = self.image_processor.crop_size["height"]
+            grid_w, grid_h = get_anyres_image_grid_shape((img.image_w, img.image_h), image_grid_pinpoints, crop_size)
+            # (Ny, Nx, patch_side, patch_side, hidden)
+            rest = rest.view(grid_w * grid_h, patch_side, patch_side, hidden)
+            rest = rest.view(grid_h, grid_w, patch_side, patch_side, hidden)
+            # (hidden, Ny, patch_side, Nx, patch_side) -> (hidden, H, W)
+            rest = rest.permute(4, 0, 2, 1, 3).contiguous()
+            H = grid_h * patch_side
+            W = grid_w * patch_side
+            rest = rest.view(hidden, H, W)
+
+            # anyres_max downsampling
+            m = re.search(r"anyres_max_(\d+)", str(aspect))
+            if m is not None:
+                max_num_patches = int(m.group(1))
+                times = (H * W) / (max_num_patches * patch_len)
+                if times > 1.1:
+                    scale = (int(H // (times ** 0.5)), int(W // (times ** 0.5)))
+                    rest = F.interpolate(rest.unsqueeze(0), size=scale, mode="bilinear", align_corners=False)[0]
+                    H, W = rest.shape[1], rest.shape[2]
+
+            # Append a newline column (width becomes W+1); the newline column is a zero-vector placeholder
+            newline_col = torch.zeros((hidden, H, 1), device=rest.device, dtype=rest.dtype)
+            rest = torch.cat([rest, newline_col], dim=2)  # (hidden, H, W+1)
+            # Flatten to (H*(W+1), hidden)
+            rest = rest.flatten(1, 2).transpose(0, 1).contiguous()
+
+            # Concatenate base + rest
+            seq = torch.cat([base_feature, rest], dim=0)
+            new_embeds.append(seq)
+
+            # Record the interval
+            valid_ids.append([valid_id, valid_id + seq.shape[0]])
+            valid_id += seq.shape[0]
+
+        # Concatenate the reassembled embeddings of all images
+        all_new = torch.cat(new_embeds, dim=0)
+        return all_new, uuids, valid_ids
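
To sanity-check the reassembly above, the shape arithmetic can be reproduced standalone with dummy tensors (made-up toy sizes, no model involved); the final sequence length should come out to patch_len + H * (W + 1).

```python
import torch

# Toy sizes: 2x2 anyres grid, 4x4 patches per view, hidden dim 8
patch_side, hidden = 4, 8
grid_w, grid_h = 2, 2
patch_len = patch_side * patch_side          # 16 tokens per view
K = grid_w * grid_h + 1                      # 4 crops + 1 base view
embeds = torch.randn(K * patch_len, hidden)  # (views * patch_len, hidden)

base = embeds[:patch_len]
rest = embeds[patch_len:].view(K - 1, patch_len, hidden)
rest = rest.view(grid_h, grid_w, patch_side, patch_side, hidden)
rest = rest.permute(4, 0, 2, 1, 3).contiguous()   # (hidden, Ny, P, Nx, P)
H, W = grid_h * patch_side, grid_w * patch_side
rest = rest.view(hidden, H, W)

newline = torch.zeros(hidden, H, 1)               # zero-vector newline column
rest = torch.cat([rest, newline], dim=2)          # (hidden, H, W+1)
rest = rest.flatten(1, 2).transpose(0, 1)         # (H*(W+1), hidden)

seq = torch.cat([base, rest], dim=0)
assert seq.shape[0] == patch_len + H * (W + 1)    # 16 + 8 * 9 = 88
```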

lightllm/models/mineru2_qwen/model.py

Lines changed: 35 additions & 10 deletions
@@ -11,6 +11,7 @@

 from ..mineru2_qwen.image_processing_mineru2 import Mineru2ImageProcessor
 from .image_processing_mineru2 import get_anyres_image_grid_shape
+import math

 IMG_START_TOKEN = "<img>"
 IMG_END_TOKEN = "</img>"
@@ -61,22 +62,46 @@ def init_audioitem_extral_params(
         raise NotImplementedError

     def get_image_token_length(self, img: ImageItem):
-        # Back to the patch sequence: total tokens = view count x patches per view
-        # patches per view = self.image_length = (image_size // patch_size) ** 2
+        # Non-anyres: a single view, just the base patch sequence
         patch_len = int(self.image_length)
+        aspect_ratio = getattr(self.image_processor, "image_aspect_ratio", None)
+        if not aspect_ratio or ("anyres" not in str(aspect_ratio)):
+            return patch_len

+        # anyres: count tokens following the reference spatial + unpad + anyres_max logic
         crop_size = self.image_processor.crop_size["height"]
         grid_w, grid_h = get_anyres_image_grid_shape(
             (img.image_w, img.image_h), self.image_processor.image_grid_pinpoints, crop_size
         )
-        views = int(grid_w * grid_h + 1)
-        token_num = views * patch_len
-        print(
-            f"[debug] mineru2_tokenizer anyres img_size=({img.image_w},{img.image_h}) "
-            f"crop={crop_size} grid=({grid_w},{grid_h}) views={views}"
-            f" patch_len={patch_len} token_num={token_num}"
-        )
-        return token_num
+        # Base view (the original image resized to the crop size)
+        base_tokens = patch_len
+        patch_side = int(math.sqrt(patch_len))
+        # h, w: overall stitched grid size (in patches)
+        h = int(grid_h * patch_side)
+        w = int(grid_w * patch_side)
+
+        new_h, new_w = h, w
+        max_num_patches = None
+        m = re.search(r"anyres_max_(\d+)", str(aspect_ratio))
+        if m:
+            max_num_patches = int(m.group(1))
+            times = math.sqrt((h * w) / (max_num_patches * patch_len))
+            if times > 1.1:
+                new_h = int(new_h // times)
+                new_w = int(new_w // times)
+        # One newline token is appended per row, so their count equals the row count new_h
+        extra_tokens = int(new_h * (new_w + 1))
+        total_tokens = int(base_tokens + extra_tokens)
+
+        print(f"[debug][spatial] P={patch_side}, N={patch_len}, Nx={grid_w}, Ny={grid_h}, crops={grid_w*grid_h}")
+        if max_num_patches is not None:
+            times = math.sqrt((h * w) / (max_num_patches * patch_len))
+            print(
+                f"[debug][spatial+unpad+anyres_max] h={h}, w={w}, "
+                f"times={times:.4f}, h'={new_h}, w'={new_w}, newline={new_h}, extra_tokens~={extra_tokens}"
+            )
+        print(f"[debug][spatial] base_tokens={base_tokens}, extra_tokens={extra_tokens}, total_tokens={total_tokens}")
+        return total_tokens

     def get_audio_token_length(self, audio: AudioItem):
         raise NotImplementedError
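
With assumed (not repo-verified) numbers, the counting works out as follows: a 24x24 base patch grid (patch_len = 576), a 3x3 anyres grid, and "anyres_max_4".

```python
import math

patch_len = 576                          # (image_size // patch_size) ** 2 = 24 * 24
patch_side = int(math.sqrt(patch_len))   # 24
grid_w, grid_h = 3, 3                    # assumed output of get_anyres_image_grid_shape
h, w = grid_h * patch_side, grid_w * patch_side               # 72, 72

max_num_patches = 4                      # from "anyres_max_4"
times = math.sqrt((h * w) / (max_num_patches * patch_len))    # sqrt(5184 / 2304) = 1.5
new_h, new_w = (int(h // times), int(w // times)) if times > 1.1 else (h, w)  # 48, 48

extra_tokens = new_h * (new_w + 1)       # one newline token per row: 48 * 49 = 2352
total_tokens = patch_len + extra_tokens  # 576 + 2352 = 2928
```

Since mineru2_visual.py applies the same downscale-and-newline arithmetic when reassembling embeddings, the pre-computed token length and the actual embedding length stay in agreement, which is the point of this commit.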
