Skip to content

Commit 339d98e

Browse files
sangchengmengshihaobai
authored and committed
[fix] vit0331
1 parent 835ac10 commit 339d98e

File tree

14 files changed

+80
-78
lines changed

14 files changed

+80
-78
lines changed

lightllm/models/internvl/internvl_visual.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from torchvision import transforms as T
99
from torchvision.transforms.functional import InterpolationMode
1010
from transformers import AutoModel, AutoTokenizer
11+
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
1112
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
1213
from io import BytesIO
1314
from lightllm.models.internvl.img_process import load_image
@@ -43,21 +44,21 @@ def load_model(self, weight_dir):
4344
def cuda(self):
4445
return self
4546

46-
def encode(self, image_uuids: List):
47+
def encode(self, images: List[ImageItem]):
4748
img_tensors = []
4849
valid_ids = []
4950
valid_id = 0
5051
uuids = []
5152

52-
for i, url in enumerate(image_uuids):
53-
if isinstance(url, int):
54-
uuids.append(url)
55-
image_data = read_shm(get_shm_name_data(url))
53+
for i, img in enumerate(images):
54+
if isinstance(img, ImageItem):
55+
uuids.append(img.uuid)
56+
image_data = read_shm(get_shm_name_data(img.uuid))
5657
image_data = Image.open(BytesIO(image_data))
57-
t = self.load_image_func(image_data)
58+
t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
5859
img_tensors.append(t)
5960
else:
60-
raise Exception("Unsupport input types: {} for {}".format(type(url), url))
61+
raise Exception("Unsupport input types: {} for {}".format(type(img), img))
6162

6263
cur_num = img_tensors[-1].shape[0]
6364
valid_ids.append([valid_id, valid_id + cur_num])

lightllm/models/internvl/model.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,9 +40,20 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
4040
self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
4141
self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])
4242

43-
def get_image_token_length(self, img: ImageItem, max_num):
43+
def init_imageItem_extral_params(self, img: ImageItem, num_images):
44+
if img.extra_params["image_patch_max_num"] > 0:
45+
return
46+
if num_images == 1:
47+
img.extra_params["image_patch_max_num"] = 12
48+
elif num_images > 1 and num_images <= 6:
49+
img.extra_params["image_patch_max_num"] = 6
50+
elif num_images > 6:
51+
img.extra_params["image_patch_max_num"] = 0
52+
return
53+
54+
def get_image_token_length(self, img: ImageItem):
4455
return (
45-
self.get_image_patch_func(img.image_w, img.image_h, max_num=max_num, use_thumbnail=True) * self.image_length
56+
self.get_image_patch_func(img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True) * self.image_length
4657
)
4758

4859
# only change the impl of the encode func:

lightllm/models/llava/llava_visual.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from typing import List, Union
77
from safetensors import safe_open
88
from io import BytesIO
9+
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
910
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
1011
from lightllm.utils.log_utils import init_logger
1112

@@ -123,16 +124,16 @@ def forward(self, x):
123124
x = x.view(B, L, -1)
124125
return x
125126

126-
def encode(self, image_uuids: List):
127+
def encode(self, images: List[ImageItem]):
127128
img_tensors = []
128129
uuids = []
129130
valid_id = 0
130131
valid_ids = []
131132

132-
for i, item in enumerate(image_uuids):
133-
if isinstance(item, int):
134-
uuids.append(item)
135-
image_data = read_shm(get_shm_name_data(item))
133+
for i, img in enumerate(images):
134+
if isinstance(img, ImageItem):
135+
uuids.append(img.uuid)
136+
image_data = read_shm(get_shm_name_data(img.uuid))
136137
image_data = Image.open(BytesIO(image_data)).convert("RGB")
137138
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
138139
img_tensors.append(t)

lightllm/models/llava/model.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,8 @@ def __init__(self, tokenizer, model_cfg):
3333
self.image_length = (image_size // patch_size) ** 2
3434
self.skip_start = model_cfg.get("skip_start", True)
3535

36+
def init_imageItem_extral_params(self, img: ImageItem, num_images):
37+
return
3638
def get_image_token_length(self, img: ImageItem):
3739
return self.image_length
3840

lightllm/models/qwen2_vl/model.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,9 @@ def __init__(self, tokenizer=None, image_processor=None, **kwargs):
3131
self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"]
3232
self.image_token_id = kwargs["model_cfg"]["image_token_id"]
3333

34+
def init_imageItem_extral_params(self, img: ImageItem, num_images):
35+
return
36+
3437
def get_image_token_length(self, img: ImageItem):
3538
width = img.image_w
3639
height = img.image_h

lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
from transformers import AutoProcessor
4242
from safetensors import safe_open
4343
from transformers.utils import TensorType
44+
from lightllm.server.multimodal_params import MultimodalParams,ImageItem
4445
from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
4546

4647

@@ -425,17 +426,17 @@ def load_model(self, weight_dir):
425426

426427
self.load_state_dict(weight_dict)
427428

428-
def encode(self, image_uuids: List):
429+
def encode(self, images: List[ImageItem]):
429430
img_tensors = []
430431
valid_ids = []
431432
valid_id = 0
432433
img_grids = []
433434
uuids = []
434435

435-
for i, url in enumerate(image_uuids):
436-
if isinstance(url, int):
437-
uuids.append(url)
438-
image_data = read_shm(get_shm_name_data(url))
436+
for i, img in enumerate(images):
437+
if isinstance(img, ImageItem):
438+
uuids.append(img.uuid)
439+
image_data = read_shm(get_shm_name_data(img.uuid))
439440
image_data = Image.open(BytesIO(image_data))
440441
image_data = get_image(image_data)
441442
image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")

lightllm/models/qwen_vl/model.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,9 @@ def __init__(self, tokenizer, model_cfg):
1919
# <imgpad>: 151859
2020
self.image_length = model_cfg["visual"].get("n_queries", 256)
2121

22+
def init_imageItem_extral_params(self, img: ImageItem, num_images):
23+
return
24+
2225
def _list_find(self, input_list, target, start_idx):
2326
cur_list = input_list[start_idx:]
2427
if target in cur_list:

lightllm/models/vit/model.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight
88
from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
99
from lightllm.models.vit.layer_weights.hf_load_utils import load_hf_weights
10+
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
1011
from lightllm.common.build_utils import repair_config
1112
from lightllm.utils.log_utils import init_logger
1213
from lightllm.models.vit import get_load_image_func
@@ -135,21 +136,20 @@ def forward(self, pixel_values):
135136
return input_embs
136137

137138
@torch.no_grad()
138-
def encode(self, image_uuids: List, max_num_list: List):
139+
def encode(self, images: List[ImageItem]):
139140
img_tensors = []
140141
valid_ids = []
141142
valid_id = 0
142143
uuids = []
143-
for i, url in enumerate(image_uuids):
144-
if isinstance(url, int):
145-
uuids.append(url)
146-
image_data = read_shm(get_shm_name_data(url))
144+
for i, img in enumerate(images):
145+
if isinstance(img, ImageItem):
146+
uuids.append(img.uuid)
147+
image_data = read_shm(get_shm_name_data(img.uuid))
147148
image_data = Image.open(BytesIO(image_data))
148-
max_num = max_num_list[i]
149-
t = self.load_image_func(image_data, max_num=max_num)
149+
t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
150150
img_tensors.append(t)
151151
else:
152-
raise Exception("Unsupport input types: {} for {}".format(type(url), url))
152+
raise Exception("Unsupport input types: {} for {}".format(type(img), img))
153153

154154
cur_num = img_tensors[-1].shape[0]
155155
valid_ids.append([valid_id, valid_id + cur_num])
@@ -160,7 +160,6 @@ def encode(self, image_uuids: List, max_num_list: List):
160160

161161
imgs = torch.cat(img_tensors, dim=0)
162162
pixel_values = imgs.cuda().to(dtype=self.data_type)
163-
print(pixel_values.shape, pixel_values.dtype)
164163
all_img_embeds = self.forward(pixel_values)
165164
return all_img_embeds, uuids, valid_ids
166165

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 1 addition & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
class Record(object):
1515
id: int
1616
md5sum: str
17-
max_num: int
1817
ref: int
1918
data: bool
2019
embed: bool
@@ -70,14 +69,11 @@ def alloc(self, md5sum: str, token_num: int) -> dict:
7069
self._clear()
7170
if self.occupied >= self.capacity:
7271
return None
73-
_, max_num_str = md5sum.rsplit("_", 1)
74-
max_num = int(max_num_str)
7572
id = uuid.uuid1()
7673
id = id.int
7774
record = Record(
7875
id=id,
7976
md5sum=md5sum,
80-
max_num=max_num,
8177
ref=1,
8278
data=False,
8379
embed=False,
@@ -113,7 +109,4 @@ def set_item_embed(self, id: int) -> None:
113109
self._records[id].embed = True
114110

115111
def get_item_embed(self, id: int) -> bool:
116-
return self._records[id].embed
117-
118-
def get_max_num(self, id: int) -> int:
119-
return self._records[id].max_num
112+
return self._records[id].embed

lightllm/server/embed_cache/manager.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,10 +48,6 @@ def exposed_get_item_embed(self, id: int) -> bool:
4848
id = obtain(id)
4949
return self._impl.get_item_embed(id=id)
5050

51-
def exposed_get_max_num(self, id: int) -> int:
52-
id = obtain(id)
53-
return self._impl.get_max_num(id=id)
54-
5551

5652
def start_cache_manager(port: int, args, pipe_writer):
5753
# 注册graceful 退出的处理

0 commit comments

Comments
 (0)