Skip to content
Merged
99 changes: 39 additions & 60 deletions lightllm/models/internvl/img_process.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,56 @@
import torch
import torch.nn.functional as F
from PIL import Image
import math
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
def find_closest_aspect_ratio(width, height, min_num=1, max_num=6, image_size=448):
"""
Find the closest aspect ratio from a list of target ratios to match the given aspect ratio.
If the difference is the same, use the area to decide the better ratio.
"""
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
# Compare areas to decide the better ratio when the difference is the same
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
assert min_num == 1
log_ratio = math.log(width / height)
ratio = width * height / (image_size * image_size)
multiple = min(math.ceil(ratio), max_num)
if multiple <= 1:
return [1, 1]
candidate_split_grids_nums = []
for i in [multiple - 1, multiple, multiple + 1]:
if i > max_num:
continue
candidate_split_grids_nums.append(i)

candidate_grids = []
for split_grids_nums in candidate_split_grids_nums:
m = 1
while m <= split_grids_nums:
if split_grids_nums % m == 0:
candidate_grids.append([m, split_grids_nums // m])
m += 1
best_grid = [1, 1]
min_error = float("inf")
for grid in candidate_grids:
error = abs(log_ratio - math.log(grid[0] / grid[1]))
if error < min_error:
best_grid = grid
min_error = error

return best_grid


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
"""
Preprocess the image dynamically by finding the closest aspect ratio,
resizing the image, and splitting it into smaller blocks.
Optionally add a thumbnail version of the image.
"""
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height

# Calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

# Find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

# Calculate the target width and height
original_width, original_height = image.size
target_aspect_ratio = find_closest_aspect_ratio(original_width, original_height, min_num, max_num, image_size)
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

# Resize the image to the target dimensions
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
Expand All @@ -63,40 +60,22 @@ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnai
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# Split the image into blocks
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)

assert len(processed_images) == blocks

# Optionally add a thumbnail version of the image
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)

return processed_images


def get_image_patch(orign_width, orign_height, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
def get_image_patch(orign_width, orign_height, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
"""
Calculate the number of image patches based on the closest aspect ratio
and the given width and height of the original image.
"""
aspect_ratio = orign_width / orign_height

# calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orign_width, orign_height, image_size)

target_aspect_ratio = find_closest_aspect_ratio(orign_width, orign_height, min_num, max_num, image_size)
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
if use_thumbnail and blocks != 1:
blocks += 1
Expand Down
6 changes: 4 additions & 2 deletions lightllm/models/internvl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
self.get_image_patch_func = get_image_patch_func(kwargs["weight_dir"])

def get_image_token_length(self, img: ImageItem):
return self.get_image_patch_func(img.image_w, img.image_h, use_thumbnail=True) * self.image_length
def get_image_token_length(self, img: ImageItem, max_num=6):
    """Return the number of LLM tokens this image will occupy after patching.

    Args:
        img: image metadata carrying ``image_w`` / ``image_h``.
        max_num: maximum number of grid patches allowed for the image.
            Defaults to 6, matching ``get_image_patch``'s own default, so
            older call sites that do not pass the argument keep working.

    Returns:
        Patch count (thumbnail included via ``use_thumbnail=True``)
        multiplied by ``self.image_length`` tokens per patch.
    """
    patch_count = self.get_image_patch_func(img.image_w, img.image_h, max_num=max_num, use_thumbnail=True)
    return patch_count * self.image_length

# only change the impl of the encode func:
def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
Expand Down
7 changes: 4 additions & 3 deletions lightllm/models/vit/layer_infer/post_layer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.distributed as dist
from lightllm.models.vit.layer_weights.pre_and_post_layer_weight import ViTPreAndPostLayerWeight
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size

from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd

class ViTPostLayerInfer:
""" """
Expand Down Expand Up @@ -44,8 +44,9 @@ def forward(self, vit_embeds, layer_weight: ViTPreAndPostLayerWeight):
layer_weight.mlp1_1_bias_, vit_embeds_norm.view(-1, vit_embeds_norm.shape[-1]), layer_weight.mlp1_1_weight_
)

vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)

# vit_embeds_gelu = torch.nn.functional.gelu(vit_embeds_1)
vit_embeds_gelu = gelu_fwd(vit_embeds_1)

vit_embeds_out = torch.addmm(
layer_weight.mlp1_3_bias_,
vit_embeds_gelu.view(-1, self.llm_hidden_size // self.tp_world_size_),
Expand Down
10 changes: 6 additions & 4 deletions lightllm/models/vit/layer_infer/transformer_layer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward, torch_rms_norm
from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size

from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm

class ViTTransformerLayerInfer:
""" """
Expand Down Expand Up @@ -58,7 +59,7 @@ def tp_norm(self, input, weight):

def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
if layer_weight.norm_type == "rms_norm":
b = rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
b = rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
else:
b = torch.nn.functional.layer_norm(
input,
Expand All @@ -71,7 +72,7 @@ def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Ten

def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
if layer_weight.norm_type == "rms_norm":
return rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
return rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
else:
return torch.nn.functional.layer_norm(
input,
Expand Down Expand Up @@ -113,7 +114,8 @@ def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor

def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
ffn1_out = torch.nn.functional.gelu(fc1)
# ffn1_out = torch.nn.functional.gelu(fc1)
ffn1_out = gelu_fwd(fc1)
input_shape = input.shape
input = None
ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=False)
Expand Down
5 changes: 3 additions & 2 deletions lightllm/models/vit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def forward(self, pixel_values):
return input_embs

@torch.no_grad()
def encode(self, image_uuids: List):
def encode(self, image_uuids: List, max_num_list: List):
img_tensors = []
valid_ids = []
valid_id = 0
Expand All @@ -145,7 +145,8 @@ def encode(self, image_uuids: List):
uuids.append(url)
image_data = read_shm(get_shm_name_data(url))
image_data = Image.open(BytesIO(image_data))
t = self.load_image_func(image_data)
max_num = max_num_list[i]
t = self.load_image_func(image_data, max_num=max_num)
img_tensors.append(t)
else:
raise Exception("Unsupport input types: {} for {}".format(type(url), url))
Expand Down
55 changes: 53 additions & 2 deletions lightllm/models/vit/triton_kernel/flashattention_nopad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.nn.functional as F

TESLA = "Tesla" in torch.cuda.get_device_name(0)
HOPPER = "H100" in torch.cuda.get_device_name(0) or "H800" in torch.cuda.get_device_name(0) or "Hopper" in torch.cuda.get_device_name(0)


if triton.__version__ >= "2.1.0":
Expand Down Expand Up @@ -101,7 +102,7 @@ def _fwd_kernel(
return

@torch.no_grad()
def flash_attention_fwd(
def _flash_attention_triton_fwd(
q,
k,
v,
Expand Down Expand Up @@ -149,6 +150,56 @@ def flash_attention_fwd(
else:
raise Exception("error triton version!")

# Try to pick up the FlashAttention-3 forward entry point; fall back to the
# Triton implementation when the package is not installed.
_flash_attn_v3_available = False
try:
    from flash_attn_interface import _flash_attn_forward

    _flash_attn_v3_available = True

    def flash_attention_v3_fwd(q, k, v, o):
        """Run the FlashAttention-3 forward kernel, writing the result into ``o``.

        Non-causal, full-window attention with softmax scale 1/sqrt(head_dim).
        The long run of ``None`` positional arguments fills the optional
        varlen/paging/rotary parameters of ``_flash_attn_forward``.
        """
        head_dim = q.shape[-1]
        softmax_scale = head_dim ** -0.5
        _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            o,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k
            None, None,  # rotary_cos/sin
            None, None, None,  # NOTE(review): presumably q/k/v descale — confirm against flash_attn_interface
            softmax_scale,
            causal=False,
            window_size=(-1, -1),
            softcap=0.0,
            num_splits=1,
            pack_gqa=None,
            sm_margin=0,
        )
        return

except ImportError:
    # Message previously said "hopper.flash_attn_interface", but the module
    # actually imported above is "flash_attn_interface".
    print("Failed to import _flash_attn_forward from flash_attn_interface; falling back to the Triton kernel.")


def flash_attention_fwd(q, k, v, o):
    """Unified flash-attention entry point.

    Dispatches to the FlashAttention-3 kernel when ``flash_attn_interface``
    imported successfully and the device is Hopper-class; otherwise uses the
    Triton fallback. Results are written into ``o`` in-place.
    """
    use_v3 = _flash_attn_v3_available and HOPPER
    if use_v3:
        flash_attention_v3_fwd(q, k, v, o)
        return
    _flash_attention_triton_fwd(q, k, v, o)



def torch_att(q, k, v):
head_dim = q.shape[-1]
Expand Down Expand Up @@ -188,4 +239,4 @@ def test():

print("max ", torch.max(torch.abs(torch_out - o)))
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
27 changes: 27 additions & 0 deletions lightllm/models/vit/triton_kernel/gelu_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch
import triton
import triton.language as tl


@triton.jit
def gelu(x):
    """Exact (erf-based) GELU, computed in float32 for numerical stability."""
    x_fp32 = x.to(tl.float32)
    # 0.7071067811865476 == 1/sqrt(2); the previous literal was truncated
    # to 10 digits, introducing a tiny avoidable error in the erf argument.
    x_gelu = 0.5 * x_fp32 * (1 + tl.math.erf(x_fp32 * 0.7071067811865476))
    return x_gelu

@triton.jit
def gelu_kernel(output_ptr, input_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise GELU: each program handles one BLOCK_SIZE-wide slice."""
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = offs < n_elements
    vals = tl.load(input_ptr + offs, mask=in_bounds)
    tl.store(output_ptr + offs, gelu(vals), mask=in_bounds)

def gelu_fwd(input):
    """Apply exact GELU elementwise via the Triton kernel.

    Args:
        input: CUDA tensor of any shape.

    Returns:
        A new tensor with the same shape and dtype, GELU applied.
    """
    # The kernel addresses memory as one flat buffer, so a non-contiguous
    # input would be read at wrong offsets; normalize the layout first.
    if not input.is_contiguous():
        input = input.contiguous()
    output = torch.empty_like(input)
    n_elements = input.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    gelu_kernel[grid](output, input, n_elements, BLOCK_SIZE=1024)
    return output
57 changes: 57 additions & 0 deletions lightllm/models/vit/triton_kernel/rms_norm_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
import triton
import triton.language as tl
from torch import Tensor


@triton.jit
def rms_norm_kernel(
    input,
    weight,
    output,
    input_row_stride: tl.constexpr,
    eps: tl.constexpr,
    N_COLS: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Rms norm kernel."""
    # One program per row; BLOCK_N covers the whole feature dimension
    # (the launcher sets it to next_power_of_2(N_COLS)).
    prog_id = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_N)

    # Per-feature scale; masked so the power-of-2 tail past N_COLS is ignored.
    w = tl.load(weight + offsets, mask=offsets < N_COLS)

    x_ptr = input + prog_id * input_row_stride
    x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
    # Accumulate the mean-of-squares in float32 for numerical stability.
    xf = x.to(tl.float32)

    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
    out = xf / tl.sqrt(var + eps)
    # Scale, then cast back to the input dtype before storing.
    out = (w * out).to(x.dtype)

    # Output rows use the same stride as the input rows — assumes the
    # launcher allocated the output with matching layout.
    out_ptr = output + prog_id * input_row_stride
    tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)


def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-5):
    """RMS-normalize the last dimension of ``hidden_states`` and scale by ``weight``."""
    feat_size = weight.shape[0]
    # Flatten all leading dimensions: one kernel program per row.
    num_rows = hidden_states.numel() // hidden_states.size(-1)
    row_stride = hidden_states.stride(-2)

    out = torch.empty_like(hidden_states)
    rms_norm_kernel[(num_rows,)](
        hidden_states,
        weight,
        out,
        input_row_stride=row_stride,
        eps=eps,
        N_COLS=feat_size,
        BLOCK_N=triton.next_power_of_2(feat_size),
        num_warps=4,
        num_stages=3,
    )
    return out
Loading
Loading