
Commit 9e270be

tokenizer has bug

1 parent 11d26b6 commit 9e270be

3 files changed: +62 -27 lines changed

lightllm/models/mineru2_qwen/image_processing_mineru2.py

Lines changed: 0 additions & 1 deletion

@@ -161,7 +161,6 @@ def _preprocess(self, images):
         if isinstance(images, Image.Image):
             images = [images]
         else:
-            # to adapt video data
             images = [to_numpy_array(image) for image in images]
         assert isinstance(images, list)

lightllm/models/mineru2_qwen/mineru2_visual.py

Lines changed: 1 addition & 0 deletions

@@ -126,6 +126,7 @@ def encode(self, images: List[ImageItem]) -> Tuple[torch.Tensor, List[str], List
                 print(f"[debug] mineru2_visual unsqueeze t.ndim: {t.ndim}, t.shape: {t.shape}")
                 t = t.unsqueeze(0)
             img_tensors.append(t)
+            img.token_num = t.shape[0]
         else:
             raise Exception("Unsupport input types: {} for {}".format(type(img), img))

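For context: `token_num` is the bookkeeping the reworked tokenizer below relies on, one placeholder id per row of the final image tensor. A minimal sketch of that contract, without torch; only the `token_id`/`token_num` field names and the `t.shape[0]` assignment come from this diff, the rest is hypothetical:

    # Encoder-side bookkeeping: after the optional unsqueeze, t.shape[0]
    # is recorded as the number of visual tokens for this image.
    from dataclasses import dataclass

    @dataclass
    class ImageItemStub:
        token_id: int = 0   # first placeholder id, assigned elsewhere in lightllm
        token_num: int = 0  # filled in by the visual encoder, as in this commit

    def record_token_num(img: ImageItemStub, t_shape: tuple) -> None:
        img.token_num = t_shape[0]  # mirrors `img.token_num = t.shape[0]`

    img = ImageItemStub(token_id=9000)
    record_token_num(img, (729, 4096))  # e.g. 729 visual tokens of width 4096
    assert img.token_num == 729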
lightllm/models/mineru2_qwen/model.py

Lines changed: 61 additions & 26 deletions

@@ -7,16 +7,31 @@
 from lightllm.server.core.objs import SamplingParams
 from lightllm.models.registry import ModelRegistry
 from lightllm.models.qwen2.model import Qwen2TpPartModel
-from lightllm.models.qwen2_vl.vision_process import smart_resize
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLLlamaPreAndPostLayerWeight
+from lightllm.models.internvl.img_process import get_image_patch

 from ..mineru2_qwen.image_processing_mineru2 import Mineru2ImageProcessor
+from .image_processing_mineru2 import get_anyres_image_grid_shape
+
+IMG_START_TOKEN = "<img>"
+IMG_END_TOKEN = "</img>"
+IMG_TOKEN = "<image>"


 class Mineru2QwenTokenizer(BaseMultiModalTokenizer):
     def __init__(self, tokenizer, model_cfg):
         super().__init__(tokenizer)
-        self.image_token = model_cfg.get("image_token", "<image>")
-        # for llava-v1.5-7b-hf model
+
+        self.image_token = model_cfg.get("image_token", IMG_TOKEN)
+        self.img_token_index = model_cfg.get("image_token_index", 151646)
+
+        self.image_start_tag = IMG_START_TOKEN
+        self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag)
+
+        self.image_end_tag = IMG_END_TOKEN
+        self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
+
         if "text_config" in model_cfg:
             patch_size = model_cfg["vision_config"]["patch_size"]
             image_size = model_cfg["vision_config"]["image_size"]
@@ -30,9 +45,12 @@ def __init__(self, tokenizer, model_cfg):
             default_img_size = int(vision_tower_match.group(3))
             image_size = model_cfg.get("img_size", default_img_size)
             image_size = model_cfg.get("mm_image_size", image_size)
-        # (image_size // patch_size) ** 2: (384 // 14) ** 2 = 729
+
+        self.image_processor = Mineru2ImageProcessor(
+            image_aspect_ratio=getattr(model_cfg, "image_aspect_ratio", None),
+            image_grid_pinpoints=getattr(model_cfg, "image_grid_pinpoints", None),
+        )
         self.image_length = (image_size // patch_size) ** 2
-        self.skip_start = model_cfg.get("skip_start", True)

     def init_imageitem_extral_params(
         self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
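The deleted comment spelled out where `image_length` comes from: with the 384-px vision tower and 14-px patches it assumed, `(384 // 14) ** 2 = 27 ** 2 = 729` image tokens. A one-line check under those assumed defaults:

    patch_size, image_size = 14, 384  # defaults implied by the deleted comment
    assert (image_size // patch_size) ** 2 == 729  # 27 patches per side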
@@ -52,30 +70,47 @@ def get_audio_token_length(self, audio: AudioItem):

     # only change the impl of the encode func:
     def encode(self, prompt, multimodal_params: MultimodalParams = None, add_special_tokens: bool = True):
-        image_token_id = getattr(self, "image_token_index", 151646)
-        image_token = self.image_token
-
-        text_parts = prompt.split(image_token)
-        token_ids = []
-        image_offsets = []
-        offset = 0
-        for i, part in enumerate(text_parts):
-            part_ids = self.tokenizer.encode(part, add_special_tokens=(add_special_tokens if i == 0 else False))
-            token_ids.extend(part_ids)
-            offset += len(part_ids)
-            if i < len(text_parts) - 1:
-                token_ids.append(image_token_id)
-                image_offsets.append(offset)
-                offset += 1
-
-        # record image_offsets for post-processing
-        if multimodal_params is not None:
-            multimodal_params.image_offsets = image_offsets
-            # multimodal_params.image_pad_len can be filled in during post-processing
-        return token_ids
+        # TEXT<image>TEXT<image>TEXT --> TEXT<img></img>TEXT<img></img>TEXT
+        image_tokens = IMG_START_TOKEN + IMG_END_TOKEN
+        if multimodal_params is None:
+            return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        image_count = len(multimodal_params.images)
+        prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count)
+
+        origin_ids = self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        # <img></img> --> <img>id,id+1...id+num</img>
+        input_ids = []
+        image_id = 0
+        start_idx = 0
+        while True:
+            try:
+                start_idx = origin_ids.index(self.image_start_id, start_idx)
+                if start_idx + 1 >= len(origin_ids):
+                    break
+                if origin_ids[start_idx + 1] == self.image_end_id:
+                    input_ids.extend(origin_ids[: start_idx + 1])
+                    token_id = multimodal_params.images[image_id].token_id
+                    token_num = multimodal_params.images[image_id].token_num
+                    input_ids.extend(range(token_id, token_id + token_num))
+                    input_ids.append(self.image_end_id)
+                    origin_ids = origin_ids[start_idx + 2 :]
+                    start_idx = 0
+                    image_id += 1
+                else:
+                    raise ValueError("image token error")
+            except ValueError:
+                break
+        input_ids.extend(origin_ids[start_idx:])
+        return input_ids


 @ModelRegistry("mineru2_qwen", is_multimodal=True)
 class Mineru2QwenForCausalLM(Qwen2TpPartModel):
+    # weight class
+    pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight
+
+    # infer class
+    pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+
     def __init__(self, kvargs):
         super().__init__(kvargs)
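To make the new control flow concrete, here is a standalone sketch of the expansion that encode() performs, runnable without lightllm. FakeImage, expand_image_tokens, and the two id constants are hypothetical stand-ins; only the splice logic mirrors the diff:

    # Standalone sketch of the <img></img> expansion from this commit.
    # Plain ints stand in for self.image_start_id / self.image_end_id, and
    # FakeImage mimics the token_id / token_num fields that
    # mineru2_visual.py now fills in.
    from dataclasses import dataclass
    from typing import List

    IMG_START_ID = 151647  # hypothetical id for "<img>"
    IMG_END_ID = 151648    # hypothetical id for "</img>"

    @dataclass
    class FakeImage:
        token_id: int   # first cached embedding id for this image
        token_num: int  # number of visual tokens to splice in

    def expand_image_tokens(origin_ids: List[int], images: List[FakeImage]) -> List[int]:
        # <img></img> --> <img>id, id+1, ..., id+num-1</img>, as in the commit
        input_ids: List[int] = []
        image_id = 0
        start_idx = 0
        while True:
            try:
                start_idx = origin_ids.index(IMG_START_ID, start_idx)
                if start_idx + 1 >= len(origin_ids):
                    break
                if origin_ids[start_idx + 1] == IMG_END_ID:
                    input_ids.extend(origin_ids[: start_idx + 1])
                    token_id = images[image_id].token_id
                    token_num = images[image_id].token_num
                    input_ids.extend(range(token_id, token_id + token_num))
                    input_ids.append(IMG_END_ID)
                    origin_ids = origin_ids[start_idx + 2 :]
                    start_idx = 0
                    image_id += 1
                else:
                    raise ValueError("image token error")
            except ValueError:
                break
        input_ids.extend(origin_ids[start_idx:])
        return input_ids

    # "hi <img></img> bye" with one 3-token image starting at placeholder id 9000
    ids = [1, 2, IMG_START_ID, IMG_END_ID, 3]
    print(expand_image_tokens(ids, [FakeImage(token_id=9000, token_num=3)]))
    # -> [1, 2, 151647, 9000, 9001, 9002, 151648, 3]

Encoding the literal `<img></img>` pair first and splicing ids in afterwards keeps the expansion independent of how the tokenizer happens to segment the surrounding text.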
