
Commit 9f9a7d3

Author: wanghao7

Merge remote-tracking branch 'origin/main' into prefill_overlap

Parents: 1b3d4ff + 750957f

File tree

24 files changed: +425 additions, −141 deletions


lightllm/models/internvl/img_process.py

Lines changed: 39 additions & 60 deletions
@@ -1,59 +1,56 @@
 import torch
-import torch.nn.functional as F
-from PIL import Image
+import math
 from torchvision import transforms as T
 from torchvision.transforms.functional import InterpolationMode
 
 
-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+def find_closest_aspect_ratio(width, height, min_num=1, max_num=6, image_size=448):
     """
     Find the closest aspect ratio from a list of target ratios to match the given aspect ratio.
     If the difference is the same, use the area to decide the better ratio.
     """
-    best_ratio_diff = float("inf")
-    best_ratio = (1, 1)
-    area = width * height
-    for ratio in target_ratios:
-        target_aspect_ratio = ratio[0] / ratio[1]
-        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-        if ratio_diff < best_ratio_diff:
-            best_ratio_diff = ratio_diff
-            best_ratio = ratio
-        elif ratio_diff == best_ratio_diff:
-            # Compare areas to decide the better ratio when the difference is the same
-            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                best_ratio = ratio
-    return best_ratio
-
-
-def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+    assert min_num == 1
+    log_ratio = math.log(width / height)
+    ratio = width * height / (image_size * image_size)
+    multiple = min(math.ceil(ratio), max_num)
+    if multiple <= 1:
+        return [1, 1]
+    candidate_split_grids_nums = []
+    for i in [multiple - 1, multiple, multiple + 1]:
+        if i > max_num:
+            continue
+        candidate_split_grids_nums.append(i)
+
+    candidate_grids = []
+    for split_grids_nums in candidate_split_grids_nums:
+        m = 1
+        while m <= split_grids_nums:
+            if split_grids_nums % m == 0:
+                candidate_grids.append([m, split_grids_nums // m])
+            m += 1
+    best_grid = [1, 1]
+    min_error = float("inf")
+    for grid in candidate_grids:
+        error = abs(log_ratio - math.log(grid[0] / grid[1]))
+        if error < min_error:
+            best_grid = grid
+            min_error = error
+
+    return best_grid
+
+
+def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
     """
     Preprocess the image dynamically by finding the closest aspect ratio,
     resizing the image, and splitting it into smaller blocks.
     Optionally add a thumbnail version of the image.
     """
-    orig_width, orig_height = image.size
-    aspect_ratio = orig_width / orig_height
-
-    # Calculate the existing image aspect ratio
-    target_ratios = set(
-        (i, j)
-        for n in range(min_num, max_num + 1)
-        for i in range(1, n + 1)
-        for j in range(1, n + 1)
-        if i * j <= max_num and i * j >= min_num
-    )
-    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-    # Find the closest aspect ratio to the target
-    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
-
-    # Calculate the target width and height
+    original_width, original_height = image.size
+    target_aspect_ratio = find_closest_aspect_ratio(original_width, original_height, min_num, max_num, image_size)
     target_width = image_size * target_aspect_ratio[0]
     target_height = image_size * target_aspect_ratio[1]
     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-    # Resize the image to the target dimensions
+    # resize the image
     resized_img = image.resize((target_width, target_height))
     processed_images = []
     for i in range(blocks):
@@ -63,40 +60,22 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai
             ((i % (target_width // image_size)) + 1) * image_size,
             ((i // (target_width // image_size)) + 1) * image_size,
         )
-        # Split the image into blocks
+        # split the image
        split_img = resized_img.crop(box)
         processed_images.append(split_img)
-
     assert len(processed_images) == blocks
-
-    # Optionally add a thumbnail version of the image
     if use_thumbnail and len(processed_images) != 1:
         thumbnail_img = image.resize((image_size, image_size))
         processed_images.append(thumbnail_img)
-
     return processed_images
 
 
-def get_image_patch(orign_width, orign_height, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+def get_image_patch(orign_width, orign_height, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
     """
     Calculate the number of image patches based on the closest aspect ratio
     and the given width and height of the original image.
     """
-    aspect_ratio = orign_width / orign_height
-
-    # calculate the existing image aspect ratio
-    target_ratios = set(
-        (i, j)
-        for n in range(min_num, max_num + 1)
-        for i in range(1, n + 1)
-        for j in range(1, n + 1)
-        if i * j <= max_num and i * j >= min_num
-    )
-    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-    # find the closest aspect ratio to the target
-    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orign_width, orign_height, image_size)
-
+    target_aspect_ratio = find_closest_aspect_ratio(orign_width, orign_height, min_num, max_num, image_size)
     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
     if use_thumbnail and blocks != 1:
         blocks += 1
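
For intuition on the new grid selection: instead of enumerating every (i, j) target ratio, find_closest_aspect_ratio now estimates how many 448x448 tiles the image area covers and picks the factorization whose log aspect ratio is closest to the image's. A minimal usage sketch (illustrative inputs, not part of this commit):

    from lightllm.models.internvl.img_process import find_closest_aspect_ratio

    # A 1920x1080 image covers ~10.3 tiles of 448x448; capped at max_num=6,
    # the candidate tile counts are {5, 6}, and among their factorizations the
    # best log-aspect match for 16:9 is a 3x2 grid.
    print(find_closest_aspect_ratio(1920, 1080, max_num=6))  # -> [3, 2]

    # Images no larger than a single tile are never split.
    print(find_closest_aspect_ratio(300, 200, max_num=6))  # -> [1, 1]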

lightllm/models/internvl/internvl_visual.py

Lines changed: 8 additions & 7 deletions
@@ -8,6 +8,7 @@
 from torchvision import transforms as T
 from torchvision.transforms.functional import InterpolationMode
 from transformers import AutoModel, AutoTokenizer
+from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
 from io import BytesIO
 from lightllm.models.internvl.img_process import load_image
@@ -43,21 +44,21 @@ def load_model(self, weight_dir):
     def cuda(self):
         return self
 
-    def encode(self, image_uuids: List):
+    def encode(self, images: List[ImageItem]):
         img_tensors = []
         valid_ids = []
         valid_id = 0
         uuids = []
 
-        for i, url in enumerate(image_uuids):
-            if isinstance(url, int):
-                uuids.append(url)
-                image_data = read_shm(get_shm_name_data(url))
+        for i, img in enumerate(images):
+            if isinstance(img, ImageItem):
+                uuids.append(img.uuid)
+                image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
-                t = self.load_image_func(image_data)
+                t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
                 img_tensors.append(t)
             else:
-                raise Exception("Unsupport input types: {} for {}".format(type(url), url))
+                raise Exception("Unsupport input types: {} for {}".format(type(img), img))
 
         cur_num = img_tensors[-1].shape[0]
         valid_ids.append([valid_id, valid_id + cur_num])
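
Note the new calling convention: encode now receives ImageItem objects and reads both img.uuid (the shared-memory key) and img.extra_params["image_patch_max_num"] (the per-image tiling budget set by init_imageItem_extral_params in model.py below). A rough sketch of what callers are expected to supply, using a hypothetical stand-in class rather than the real ImageItem from lightllm.server.multimodal_params:

    class FakeImageItem:  # hypothetical stand-in, for illustration only
        def __init__(self, uuid, image_patch_max_num):
            self.uuid = uuid  # key used with get_shm_name_data / read_shm
            self.extra_params = {"image_patch_max_num": image_patch_max_num}

    # vision_model.encode([FakeImageItem(uuid=123, image_patch_max_num=6)])
    # would load the image from shared memory by uuid and tile it into at most
    # max_num patches (plus a thumbnail when the image is split) before the ViT.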

lightllm/models/internvl/model.py

Lines changed: 26 additions & 1 deletion
@@ -11,6 +11,7 @@
     InternVLLlamaPreAndPostLayerWeight,
     InternVLPhi3PreAndPostLayerWeight,
 )
+from lightllm.server.core.objs import SamplingParams
 from lightllm.models.internvl.layer_weights.pre_and_post_layer_weight import InternVLInternlm2PreAndPostLayerWeight
 from lightllm.models.llava.llava_visual import LlavaVisionModel
 
@@ -40,8 +41,32 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
         self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
         self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])
 
+    def init_imageItem_extral_params(
+        self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        if sampling_params.image_max_patch_num >= 0:
+            img.extra_params["image_patch_max_num"] = sampling_params.image_max_patch_num
+            return
+        elif os.getenv("MAX_PATCH_NUM"):
+            img.extra_params["image_patch_max_num"] = int(os.getenv("MAX_PATCH_NUM"))
+            return
+        else:
+            num_images = len(multi_params.images)
+            if num_images == 1:
+                img.extra_params["image_patch_max_num"] = 12
+            elif num_images > 1 and num_images <= 6:
+                img.extra_params["image_patch_max_num"] = 6
+            elif num_images > 6:
+                img.extra_params["image_patch_max_num"] = 0
+            return
+
     def get_image_token_length(self, img: ImageItem):
-        return self.get_image_patch_func(img.image_w, img.image_h, use_thumbnail=True) * self.image_length
+        return (
+            self.get_image_patch_func(
+                img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
+            )
+            * self.image_length
+        )
 
     # only change the impl of the encode func:
     def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
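
In effect, the patch budget for each image is resolved in priority order: an explicit image_max_patch_num in the request's SamplingParams, then the MAX_PATCH_NUM environment variable, then a default keyed to how many images the request carries (1 image -> 12, 2-6 images -> 6, more than 6 -> 0, i.e. no dynamic splitting). A standalone sketch of the same selection order, with a hypothetical helper name chosen here for illustration:

    import os

    def select_image_patch_max_num(num_images, image_max_patch_num=-1):
        # Mirrors InternVL's init_imageItem_extral_params priority order.
        if image_max_patch_num >= 0:  # 1. explicit per-request sampling param
            return image_max_patch_num
        if os.getenv("MAX_PATCH_NUM"):  # 2. server-wide environment override
            return int(os.getenv("MAX_PATCH_NUM"))
        if num_images == 1:  # 3. heuristic based on image count
            return 12
        elif 1 < num_images <= 6:
            return 6
        elif num_images > 6:
            return 0  # one 448x448 tile per image, no dynamic splitting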

lightllm/models/llava/llava_visual.py

Lines changed: 7 additions & 6 deletions
@@ -6,6 +6,7 @@
 from typing import List, Union
 from safetensors import safe_open
 from io import BytesIO
+from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
 from lightllm.utils.log_utils import init_logger
 
@@ -123,21 +124,21 @@ def forward(self, x):
         x = x.view(B, L, -1)
         return x
 
-    def encode(self, image_uuids: List):
+    def encode(self, images: List[ImageItem]):
         img_tensors = []
         uuids = []
         valid_id = 0
         valid_ids = []
 
-        for i, item in enumerate(image_uuids):
-            if isinstance(item, int):
-                uuids.append(item)
-                image_data = read_shm(get_shm_name_data(item))
+        for i, img in enumerate(images):
+            if isinstance(img, ImageItem):
+                uuids.append(img.uuid)
+                image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data)).convert("RGB")
                 t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
                 img_tensors.append(t)
             else:
-                raise Exception("Unsupport input types: {} for {}".format(type(item), item))
+                raise Exception("Unsupport input types: {} for {}".format(type(img), img))
 
         cur_num = img_tensors[-1].shape[0]
         valid_ids.append([valid_id, valid_id + cur_num])

lightllm/models/llava/model.py

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,7 @@
 from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
 from lightllm.models.llava.layer_weights.pre_and_post_layer_weight import LlavaPreAndPostLayerWeight
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
+from lightllm.server.core.objs import SamplingParams
 from lightllm.common.build_utils import repair_config
 from transformers import AutoConfig
 
@@ -33,6 +34,11 @@ def __init__(self, tokenizer, model_cfg):
         self.image_length = (image_size // patch_size) ** 2
         self.skip_start = model_cfg.get("skip_start", True)
 
+    def init_imageItem_extral_params(
+        self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        return
+
     def get_image_token_length(self, img: ImageItem):
         return self.image_length
 

lightllm/models/qwen2_vl/model.py

Lines changed: 6 additions & 0 deletions
@@ -7,6 +7,7 @@
 from transformers.feature_extraction_utils import BatchFeature
 from transformers.image_utils import ImageInput
 from transformers.processing_utils import ProcessorMixin
+from lightllm.server.core.objs import SamplingParams
 from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
 from typing import List, Optional, Union
 from transformers.utils import TensorType, logging
@@ -31,6 +32,11 @@ def __init__(self, tokenizer=None, image_processor=None, **kwargs):
         self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"]
         self.image_token_id = kwargs["model_cfg"]["image_token_id"]
 
+    def init_imageItem_extral_params(
+        self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        return
+
     def get_image_token_length(self, img: ImageItem):
         width = img.image_w
         height = img.image_h

lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 7 additions & 6 deletions
@@ -41,6 +41,7 @@
 from transformers import AutoProcessor
 from safetensors import safe_open
 from transformers.utils import TensorType
+from lightllm.server.multimodal_params import MultimodalParams, ImageItem
 from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
 
 
@@ -425,17 +426,17 @@ def load_model(self, weight_dir):
 
         self.load_state_dict(weight_dict)
 
-    def encode(self, image_uuids: List):
+    def encode(self, images: List[ImageItem]):
         img_tensors = []
         valid_ids = []
         valid_id = 0
         img_grids = []
         uuids = []
 
-        for i, url in enumerate(image_uuids):
-            if isinstance(url, int):
-                uuids.append(url)
-                image_data = read_shm(get_shm_name_data(url))
+        for i, img in enumerate(images):
+            if isinstance(img, ImageItem):
+                uuids.append(img.uuid)
+                image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
                 image_data = get_image(image_data)
                 image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
@@ -444,7 +445,7 @@ def encode(self, image_uuids: List):
                 img_tensors.append(pixel_values)
                 img_grids.append(image_grid_thw)
             else:
-                raise Exception("Unsupport input types: {} for {}".format(type(url), url))
+                raise Exception("Unsupport input types: {} for {}".format(type(img), img))
 
         # must devide merge_length
         cur_num = img_tensors[-1].shape[0] // (self.spatial_merge_size ** 2)

lightllm/models/qwen_vl/model.py

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 import json
 import numpy as np
 import unicodedata
+from lightllm.server.core.objs import SamplingParams
 from lightllm.models.qwen.model import QWenTpPartModel
 from .layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
 from lightllm.server.multimodal_params import MultimodalParams, ImageItem
@@ -19,6 +20,11 @@ def __init__(self, tokenizer, model_cfg):
         # <imgpad>: 151859
         self.image_length = model_cfg["visual"].get("n_queries", 256)
 
+    def init_imageItem_extral_params(
+        self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+    ):
+        return
+
     def _list_find(self, input_list, target, start_idx):
         cur_list = input_list[start_idx:]
         if target in cur_list:

lightllm/models/vit/layer_infer/post_layer_infer.py

Lines changed: 3 additions & 1 deletion
@@ -3,6 +3,7 @@
 import torch.distributed as dist
 from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
+from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
 
 
 class ViTPostLayerInfer:
@@ -44,7 +45,8 @@ def forward(self, vit_embeds, layer_weight: ViTPreAndPostLayerWeight):
             layer_weight.mlp1_1_bias_, vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1]), layer_weight.mlp1_1_weight_
         )
 
-        vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)
+        # vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)
+        vit_embeds_gelu = gelu_fwd(vit_embeds_1)
 
         vit_embeds_out = torch.addmm(
             layer_weight.mlp1_3_bias_,
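
The only functional change in this file swaps torch.nn.functional.gelu for the project's Triton kernel. A quick sanity check one might run (a sketch; it assumes gelu_fwd takes a single CUDA tensor and returns the elementwise GELU, as the call site above suggests):

    import torch
    from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd

    x = torch.randn(16, 3200, dtype=torch.float16, device="cuda")
    ref = torch.nn.functional.gelu(x)  # previous implementation
    out = gelu_fwd(x)  # new Triton kernel
    # Expect agreement up to fp16 rounding (and any tanh-vs-erf approximation difference).
    print((out - ref).abs().max().item())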
