Skip to content
Merged
99 changes: 39 additions & 60 deletions lightllm/models/internvl/img_process.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,56 @@
import torch
import torch.nn.functional as F
from PIL import Image
import math
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode


def find_closest_aspect_ratio(width, height, min_num=1, max_num=6, image_size=448):
    """Pick the tiling grid ``[cols, rows]`` whose aspect ratio best matches the image.

    The tile count is estimated as ``ceil(width * height / image_size**2)`` clamped
    to ``max_num``; that count and its immediate neighbours are factorised into every
    possible ``(cols, rows)`` grid, and the grid minimising the log-aspect-ratio
    error against the image is returned.

    Args:
        width: Original image width in pixels.
        height: Original image height in pixels.
        min_num: Minimum tile count; only ``1`` is supported.
        max_num: Upper bound on the number of tiles. ``max_num <= 1`` (including 0)
            forces a single 1x1 grid.
        image_size: Side length in pixels of one square tile.

    Returns:
        list[int]: ``[cols, rows]`` of the best-matching grid; ``[1, 1]`` when the
        image fits in (or is forced into) a single tile.
    """
    assert min_num == 1
    log_ratio = math.log(width / height)
    ratio = width * height / (image_size * image_size)
    multiple = min(math.ceil(ratio), max_num)
    if multiple <= 1:
        return [1, 1]

    # Candidate tile counts: the estimate and its immediate neighbours, capped at max_num.
    candidate_split_grids_nums = []
    for i in [multiple - 1, multiple, multiple + 1]:
        if i > max_num:
            continue
        candidate_split_grids_nums.append(i)

    # Every (m, n // m) factorisation of each candidate count is a possible grid.
    candidate_grids = []
    for split_grids_nums in candidate_split_grids_nums:
        m = 1
        while m <= split_grids_nums:
            if split_grids_nums % m == 0:
                candidate_grids.append([m, split_grids_nums // m])
            m += 1

    # Keep the grid whose aspect ratio is closest to the image's, compared in log
    # space so that e.g. 2:1 and 1:2 are penalised symmetrically.
    best_grid = [1, 1]
    min_error = float("inf")
    for grid in candidate_grids:
        error = abs(log_ratio - math.log(grid[0] / grid[1]))
        if error < min_error:
            best_grid = grid
            min_error = error

    return best_grid


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
"""
Preprocess the image dynamically by finding the closest aspect ratio,
resizing the image, and splitting it into smaller blocks.
Optionally add a thumbnail version of the image.
"""
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height

# Calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

# Find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

# Calculate the target width and height
original_width, original_height = image.size
target_aspect_ratio = find_closest_aspect_ratio(original_width, original_height, min_num, max_num, image_size)
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

# Resize the image to the target dimensions
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
Expand All @@ -63,40 +60,22 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# Split the image into blocks
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)

assert len(processed_images) == blocks

# Optionally add a thumbnail version of the image
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)

return processed_images


def get_image_patch(orign_width, orign_height, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
def get_image_patch(orign_width, orign_height, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
"""
Calculate the number of image patches based on the closest aspect ratio
and the given width and height of the original image.
"""
aspect_ratio = orign_width / orign_height

# calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orign_width, orign_height, image_size)

target_aspect_ratio = find_closest_aspect_ratio(orign_width, orign_height, min_num, max_num, image_size)
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
if use_thumbnail and blocks != 1:
blocks += 1
Expand Down
15 changes: 8 additions & 7 deletions lightllm/models/internvl/internvl_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from io import BytesIO
from lightllm.models.internvl.img_process import load_image
Expand Down Expand Up @@ -43,21 +44,21 @@ def load_model(self, weight_dir):
def cuda(self):
return self

def encode(self, image_uuids: List):
def encode(self, images: List[ImageItem]):
img_tensors = []
valid_ids = []
valid_id = 0
uuids = []

for i, url in enumerate(image_uuids):
if isinstance(url, int):
uuids.append(url)
image_data = read_shm(get_shm_name_data(url))
for i, img in enumerate(images):
if isinstance(img, ImageItem):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
t = self.load_image_func(image_data)
t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
img_tensors.append(t)
else:
raise Exception("Unsupport input types: {} for {}".format(type(url), url))
raise Exception("Unsupport input types: {} for {}".format(type(img), img))

cur_num = img_tensors[-1].shape[0]
valid_ids.append([valid_id, valid_id + cur_num])
Expand Down
15 changes: 14 additions & 1 deletion lightllm/models/internvl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,21 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])

def init_imageItem_extral_params(self, img: ImageItem, num_images):
if img.extra_params["image_patch_max_num"] > 0:
return
if num_images == 1:
img.extra_params["image_patch_max_num"] = 12
elif num_images > 1 and num_images <= 6:
img.extra_params["image_patch_max_num"] = 6
elif num_images > 6:
img.extra_params["image_patch_max_num"] = 0
return

def get_image_token_length(self, img: ImageItem):
return self.get_image_patch_func(img.image_w, img.image_h, use_thumbnail=True) * self.image_length
return (
self.get_image_patch_func(img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True) * self.image_length
)

# only change the impl of the encode func:
def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
Expand Down
11 changes: 6 additions & 5 deletions lightllm/models/llava/llava_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import List, Union
from safetensors import safe_open
from io import BytesIO
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.utils.log_utils import init_logger

Expand Down Expand Up @@ -123,16 +124,16 @@ def forward(self, x):
x = x.view(B, L, -1)
return x

def encode(self, image_uuids: List):
def encode(self, images: List[ImageItem]):
img_tensors = []
uuids = []
valid_id = 0
valid_ids = []

for i, item in enumerate(image_uuids):
if isinstance(item, int):
uuids.append(item)
image_data = read_shm(get_shm_name_data(item))
for i, img in enumerate(images):
if isinstance(img, ImageItem):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data)).convert("RGB")
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
img_tensors.append(t)
Expand Down
2 changes: 2 additions & 0 deletions lightllm/models/llava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def __init__(self, tokenizer, model_cfg):
self.image_length = (image_size // patch_size) ** 2
self.skip_start = model_cfg.get("skip_start", True)

    def init_imageItem_extral_params(self, img: ImageItem, num_images):
        # LLaVA uses a fixed token count per image (see get_image_token_length),
        # so no per-image extra params are needed; this hook only satisfies the
        # shared tokenizer-wrapper interface.
        return
    def get_image_token_length(self, img: ImageItem):
        # Every image costs the same fixed number of tokens, independent of its
        # width/height: (image_size // patch_size) ** 2, computed in __init__.
        return self.image_length

Expand Down
3 changes: 3 additions & 0 deletions lightllm/models/qwen2_vl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def __init__(self, tokenizer=None, image_processor=None, **kwargs):
self.image_end_id = kwargs["model_cfg"]["vision_end_token_id"]
self.image_token_id = kwargs["model_cfg"]["image_token_id"]

    def init_imageItem_extral_params(self, img: ImageItem, num_images):
        # Qwen2-VL derives its token count from the image dimensions directly,
        # so no per-image extra params are needed; this hook only satisfies the
        # shared tokenizer-wrapper interface.
        return

def get_image_token_length(self, img: ImageItem):
width = img.image_w
height = img.image_h
Expand Down
11 changes: 6 additions & 5 deletions lightllm/models/qwen2_vl/qwen2_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from transformers import AutoProcessor
from safetensors import safe_open
from transformers.utils import TensorType
from lightllm.server.multimodal_params import MultimodalParams,ImageItem
from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor


Expand Down Expand Up @@ -425,17 +426,17 @@ def load_model(self, weight_dir):

self.load_state_dict(weight_dict)

def encode(self, image_uuids: List):
def encode(self, images: List[ImageItem]):
img_tensors = []
valid_ids = []
valid_id = 0
img_grids = []
uuids = []

for i, url in enumerate(image_uuids):
if isinstance(url, int):
uuids.append(url)
image_data = read_shm(get_shm_name_data(url))
for i, img in enumerate(images):
if isinstance(img, ImageItem):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
image_data = get_image(image_data)
image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
Expand Down
3 changes: 3 additions & 0 deletions lightllm/models/qwen_vl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def __init__(self, tokenizer, model_cfg):
# <imgpad>: 151859
self.image_length = model_cfg["visual"].get("n_queries", 256)

    def init_imageItem_extral_params(self, img: ImageItem, num_images):
        # Qwen-VL uses a fixed number of visual queries per image (n_queries),
        # so no per-image extra params are needed; this hook only satisfies the
        # shared tokenizer-wrapper interface.
        return

def _list_find(self, input_list, target, start_idx):
cur_list = input_list[start_idx:]
if target in cur_list:
Expand Down
4 changes: 3 additions & 1 deletion lightllm/models/vit/layer_infer/post_layer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.distributed as dist
from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd


class ViTPostLayerInfer:
Expand Down Expand Up @@ -44,7 +45,8 @@ def forward(self, vit_embeds, layer_weight: ViTPreAndPostLayerWeight):
layer_weight.mlp1_1_bias_, vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1]), layer_weight.mlp1_1_weight_
)

vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)
# vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)
vit_embeds_gelu = gelu_fwd(vit_embeds_1)

vit_embeds_out = torch.addmm(
layer_weight.mlp1_3_bias_,
Expand Down
9 changes: 6 additions & 3 deletions lightllm/models/vit/layer_infer/transformer_layer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm


class ViTTransformerLayerInfer:
Expand Down Expand Up @@ -58,7 +60,7 @@ def tp_norm(self, input, weight):

def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
if layer_weight.norm_type == "rms_norm":
b = rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
b = rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
else:
b = torch.nn.functional.layer_norm(
input,
Expand All @@ -71,7 +73,7 @@ def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Ten

def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
if layer_weight.norm_type == "rms_norm":
return rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
return rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
else:
return torch.nn.functional.layer_norm(
input,
Expand Down Expand Up @@ -113,7 +115,8 @@ def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor

def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
ffn1_out = torch.nn.functional.gelu(fc1)
# ffn1_out = torch.nn.functional.gelu(fc1)
ffn1_out = gelu_fwd(fc1)
input_shape = input.shape
input = None
ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=False)
Expand Down
16 changes: 8 additions & 8 deletions lightllm/models/vit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight
from lightllm.models.vit.layer_weights.transformer_layer_weight import ViTTransformerLayerWeight
from lightllm.models.vit.layer_weights.hf_load_utils import load_hf_weights
from lightllm.server.multimodal_params import MultimodalParams, ImageItem
from lightllm.common.build_utils import repair_config
from lightllm.utils.log_utils import init_logger
from lightllm.models.vit import get_load_image_func
Expand Down Expand Up @@ -135,20 +136,20 @@ def forward(self, pixel_values):
return input_embs

@torch.no_grad()
def encode(self, image_uuids: List):
def encode(self, images: List[ImageItem]):
img_tensors = []
valid_ids = []
valid_id = 0
uuids = []
for i, url in enumerate(image_uuids):
if isinstance(url, int):
uuids.append(url)
image_data = read_shm(get_shm_name_data(url))
for i, img in enumerate(images):
if isinstance(img, ImageItem):
uuids.append(img.uuid)
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
t = self.load_image_func(image_data)
t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
img_tensors.append(t)
else:
raise Exception("Unsupport input types: {} for {}".format(type(url), url))
raise Exception("Unsupport input types: {} for {}".format(type(img), img))

cur_num = img_tensors[-1].shape[0]
valid_ids.append([valid_id, valid_id + cur_num])
Expand All @@ -159,7 +160,6 @@ def encode(self, image_uuids: List):

imgs = torch.cat(img_tensors, dim=0)
pixel_values = imgs.cuda().to(dtype=self.data_type)
print(pixel_values.shape, pixel_values.dtype)
all_img_embeds = self.forward(pixel_values)
return all_img_embeds, uuids, valid_ids

Expand Down
Loading
Loading