Skip to content

Commit 80558f2

Browse files
committed
adapt mineru2 visual
1 parent 696500d commit 80558f2

File tree

6 files changed

+141
-525
lines changed

6 files changed

+141
-525
lines changed

lightllm/models/mineru2_qwen/image_processing_mineru2.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import ast
2-
import math
32
import re
43
from functools import partial, reduce
54
from typing import Dict, Optional, Union
@@ -94,33 +93,6 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
9493
return width // patch_size, height // patch_size
9594

9695

97-
# This functions is not used.
98-
def resize_and_pad_image(image, target_resolution):
99-
original_width, original_height = image.size
100-
target_width, target_height = target_resolution
101-
102-
scale_w = target_width / original_width
103-
scale_h = target_height / original_height
104-
105-
if scale_w < scale_h:
106-
new_width = target_width
107-
new_height = min(math.ceil(original_height * scale_w), target_height)
108-
else:
109-
new_height = target_height
110-
new_width = min(math.ceil(original_width * scale_h), target_width)
111-
112-
# Resize the image
113-
resized_image = image.resize((new_width, new_height))
114-
115-
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
116-
paste_x = (target_width - new_width) // 2
117-
paste_y = (target_height - new_height) // 2
118-
new_image.paste(resized_image, (paste_x, paste_y))
119-
120-
return new_image
121-
122-
123-
# DIFFERENT from sglang.srt.mm_utils.process_anyres_image
12496
def process_anyres_image(image, processor, grid_pinpoints):
12597
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
12698
patch_size = processor.crop_size["height"]
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import re
2+
import os
3+
import json
4+
5+
from typing import List
6+
from io import BytesIO
7+
from PIL import Image
8+
from safetensors import safe_open
9+
10+
import torch
11+
import torch.nn as nn
12+
from transformers import (
13+
CLIPVisionModel,
14+
CLIPVisionConfig,
15+
SiglipVisionConfig,
16+
SiglipVisionModel,
17+
)
18+
19+
from .configuration_mineru2 import Mineru2QwenConfig
20+
from .image_processing_mineru2 import Mineru2ImageProcessor
21+
22+
from lightllm.server.multimodal_params import ImageItem
23+
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
24+
from lightllm.utils.log_utils import init_logger
25+
26+
logger = init_logger(__name__)
27+
28+
29+
def build_vision_tower(config: Mineru2QwenConfig):
30+
vision_tower = getattr(config, "mm_vision_tower", getattr(config, "vision_tower", ""))
31+
model_path = getattr(config, "_name_or_path", "")
32+
33+
if "clip" in vision_tower.lower():
34+
if model_path:
35+
vision_config = CLIPVisionConfig.from_pretrained(f"{model_path}/{vision_tower}")
36+
return CLIPVisionModel(vision_config)
37+
else:
38+
vision_config = CLIPVisionConfig.from_pretrained(vision_tower)
39+
return CLIPVisionModel(vision_config)
40+
elif "siglip" in vision_tower.lower():
41+
if model_path:
42+
vision_config = SiglipVisionConfig.from_pretrained(f"{model_path}/{vision_tower}")
43+
return SiglipVisionModel(vision_config)
44+
else:
45+
vision_config = SiglipVisionConfig.from_pretrained(vision_tower)
46+
return SiglipVisionModel(vision_config)
47+
else:
48+
raise ValueError(f"Unknown vision tower: {model_path}")
49+
50+
51+
def build_vision_projector(config: Mineru2QwenConfig):
52+
projector_type = getattr(config, "mm_projector_type", "linear")
53+
54+
if projector_type == "linear":
55+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
56+
57+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
58+
if mlp_gelu_match:
59+
mlp_depth = int(mlp_gelu_match.group(1))
60+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
61+
for _ in range(1, mlp_depth):
62+
modules.append(nn.GELU())
63+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
64+
return nn.Sequential(*modules)
65+
66+
if projector_type == "identity":
67+
return nn.Identity()
68+
69+
raise ValueError(f"Unknown projector type: {projector_type}")
70+
71+
72+
class Mineru2VisionModel:
73+
def __init__(self):
74+
pass
75+
76+
def load_model(self, weight_dir):
77+
# config_file = os.path.join(weight_dir, "config.json")
78+
vision_config = Mineru2QwenConfig.from_pretrained(weight_dir)
79+
80+
self.vision_tower = build_vision_tower(vision_config)
81+
self.projector = build_vision_projector(vision_config)
82+
self.image_processor = Mineru2ImageProcessor()
83+
84+
def forward(self, x):
85+
return self.projector(self.vision_tower(x))
86+
87+
def encode(self, images: List[ImageItem]):
88+
img_tensors = []
89+
uuids = []
90+
valid_id = 0
91+
valid_ids = []
92+
93+
for i, img in enumerate(images):
94+
if isinstance(img, ImageItem):
95+
uuids.append(img.uuid)
96+
image_data = read_shm(get_shm_name_data(img.uuid))
97+
image_data = Image.open(BytesIO(image_data)).convert("RGB")
98+
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
99+
img_tensors.append(t)
100+
else:
101+
raise Exception("Unsupport input types: {} for {}".format(type(img), img))
102+
103+
cur_num = img_tensors[-1].shape[0]
104+
valid_ids.append([valid_id, valid_id + cur_num])
105+
valid_id += cur_num
106+
107+
if len(img_tensors) <= 0:
108+
return None
109+
110+
img = torch.cat(img_tensors, dim=0)
111+
all_img_embeds = self.forward(img)
112+
113+
return all_img_embeds, uuids, valid_ids
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
from lightllm.models.qwen2.model import Qwen2TpPartModel
22
from lightllm.models.registry import ModelRegistry
3-
from .configuration_mineru2 import Mineru2QwenConfig
43

54

65
@ModelRegistry("mineru2_qwen", is_multimodal=True)
76
class Mineru2QwenForCausalLM(Qwen2TpPartModel):
8-
# a new config class is not necessary
9-
config_class = Mineru2QwenConfig
10-
117
def __init__(self, kvargs):
128
super().__init__(kvargs)

0 commit comments

Comments
 (0)