Commit e7bee77

0812-add-rope-triton

1 parent 97e60fb commit e7bee77

2 files changed: +79 -164 lines

lightllm/models/qwen2_5_vl/qwen2_5_visual.py

Lines changed: 74 additions & 140 deletions
@@ -1,69 +1,21 @@
 import os
-import re
 import json
 import torch
 import torch.nn.functional as F
 from PIL import Image
-from typing import Any, Dict, List, Optional, Tuple, Union
-from torchvision import transforms as T
-from torchvision.transforms.functional import InterpolationMode
-from transformers import AutoModel, AutoTokenizer
+from typing import List, Optional
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
 from io import BytesIO
-from transformers.configuration_utils import PretrainedConfig
-from transformers.modeling_utils import PreTrainedModel
 import torch.nn as nn
-from torch.nn import LayerNorm
 from transformers.activations import ACT2FN
-import math
 from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
-from transformers import AutoProcessor
 from safetensors import safe_open
-from transformers.utils import TensorType
-from lightllm.server.multimodal_params import MultimodalParams, ImageItem
+from lightllm.server.multimodal_params import ImageItem
 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
+from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
-
-# adapted from
-# https://github.com/huggingface/transformers/blob/
-# be37d34f44ff1bc928e59ffb8a30adecab8835a8/src
-# /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1
-class Qwen2_5_VLVisionConfig(PretrainedConfig):
-    model_type = "qwen2_5_vl"
-
-    def __init__(
-        self,
-        depth=32,
-        hidden_size=3584,
-        hidden_act="silu",
-        intermediate_size=3420,
-        num_heads=16,
-        in_channels=3,
-        patch_size=14,
-        spatial_merge_size=2,
-        temporal_patch_size=2,
-        tokens_per_second=4,
-        window_size=112,
-        out_hidden_size=3584,
-        fullatt_block_indexes=[7, 15, 23, 31],
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.depth = depth
-        self.hidden_size = hidden_size
-        self.hidden_act = hidden_act
-        self.intermediate_size = intermediate_size
-        self.num_heads = num_heads
-        self.in_channels = in_channels
-        self.patch_size = patch_size
-        self.spatial_merge_size = spatial_merge_size
-        self.temporal_patch_size = temporal_patch_size
-        self.tokens_per_second = tokens_per_second
-        self.window_size = window_size
-        self.fullatt_block_indexes = fullatt_block_indexes
-        self.out_hidden_size = out_hidden_size
+from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton


 class Qwen2RMSNorm(nn.Module):
@@ -76,11 +28,7 @@ def __init__(self, hidden_size, eps=1e-6):
         self.variance_epsilon = eps

     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return rms_norm(hidden_states, self.weight, eps=self.variance_epsilon)

     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
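
The eager float32 normalization above is replaced by the fused rms_norm triton kernel imported from lightllm.models.vit.triton_kernel.rms_norm_vit. As a minimal reference sketch (not part of the commit; ref_rms_norm is a hypothetical helper), the kernel is expected to reproduce the removed implementation:

import torch

def ref_rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference semantics of the removed eager path: normalize in float32,
    # scale by the learned weight, then cast back to the input dtype.
    input_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)

# Hypothetical parity check against the fused kernel (needs CUDA + triton):
# x = torch.randn(16, 1280, dtype=torch.float16, device="cuda")
# w = torch.ones(1280, dtype=torch.float16, device="cuda")
# torch.testing.assert_close(rms_norm(x, w, eps=1e-6), ref_rms_norm(x, w), rtol=1e-2, atol=1e-2)
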
@@ -104,27 +52,6 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb_vision(
-    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    orig_q_dtype = q.dtype
-    orig_k_dtype = k.dtype
-    q, k = q.float(), k.float()
-    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    q_embed = q_embed.to(orig_q_dtype)
-    k_embed = k_embed.to(orig_k_dtype)
-    return q_embed, k_embed
-
-
 class Qwen2_5_VLVisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
@@ -137,21 +64,16 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        if position_embeddings is None:
-            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        else:
-            cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+        q = apply_rotary_pos_emb_triton(q.unsqueeze(0), rotary_pos_emb.cos(), rotary_pos_emb.sin())
+        k = apply_rotary_pos_emb_triton(k.unsqueeze(0), rotary_pos_emb.cos(), rotary_pos_emb.sin())
+        q = q.squeeze(0)
+        k = k.squeeze(0)

-        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
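
The attention forward now applies rotary embeddings to q and k via apply_rotary_pos_emb_triton (with a dummy batch dim) and takes max_seqlen from the caller instead of recomputing it from cu_seqlens in every layer. A minimal sketch of the rotate-half reference the triton kernel should match, assuming it duplicates the half-dim rotary_pos_emb across the head dimension exactly as the removed apply_rotary_pos_emb_vision did; ref_apply_rotary is illustrative only:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dim and negate the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def ref_apply_rotary(x: torch.Tensor, rotary_pos_emb: torch.Tensor) -> torch.Tensor:
    # x: (seq_len, num_heads, head_dim); rotary_pos_emb: (seq_len, head_dim // 2).
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    cos = emb.cos().unsqueeze(-2).float()  # broadcast over heads
    sin = emb.sin().unsqueeze(-2).float()
    return (x.float() * cos + rotate_half(x.float()) * sin).to(x.dtype)
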
@@ -183,14 +105,14 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
+        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
             rotary_pos_emb=rotary_pos_emb,
-            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -232,6 +154,8 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
+        self.weight_dir = kvargs["weight_dir"]
+        self.data_type = kvargs.get("data_type", "bfloat16")

         self.depth = depth
         self.hidden_size = hidden_size
@@ -257,7 +181,7 @@ def __init__(
         )

         head_dim = self.hidden_size // self.num_heads
-        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).to("cuda", dtype=self.get_dtype(), non_blocking=True)
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

         self.blocks = nn.ModuleList(
             [
@@ -279,41 +203,42 @@ def __init__(

         self.gradient_checkpointing = False

-        self.device = self.get_device()
-        self.dtype = self.get_dtype()
-
-    def get_dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.down_proj.weight.dtype
+        processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json")
+        with open(processor_config_path, "r") as f:
+            processor_config_dict = json.load(f)
+        self.processor = Qwen2VLImageProcessor(**processor_config_dict)

-    def get_device(self) -> torch.device:
-        return self.blocks[0].mlp.down_proj.weight.device
+        self._init_datatype()
+        self.load_model(kvargs["weight_dir"])
+        self.cuda()
+
+    def _init_datatype(self):
+        if isinstance(self.data_type, torch.dtype):
+            return
+        if self.data_type in ["fp16", "float16"]:
+            self.data_type = torch.float16
+        elif self.data_type in ["bf16", "bfloat16"]:
+            self.data_type = torch.bfloat16
+        elif self.data_type in ["fp32", "float32"]:
+            self.data_type = torch.float32
+        else:
+            raise ValueError(f"Unsupport datatype {self.data_type}!")
+        return

     def rot_pos_emb(self, grid_thw):
         pos_ids = []
-        for t, h, w in grid_thw:
+        s = self.spatial_merge_size
+        for _, h, w in grid_thw:
+            pos_shape = (h // s, s, w // s, s)
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            hpos_ids = hpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
-            hpos_ids = hpos_ids.flatten()
-
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            wpos_ids = wpos_ids.reshape(
-                h // self.spatial_merge_size,
-                self.spatial_merge_size,
-                w // self.spatial_merge_size,
-                self.spatial_merge_size,
-            )
-            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
-            wpos_ids = wpos_ids.flatten()
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+            hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+            wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb

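
In rot_pos_emb the four repeated reshape arguments are folded into a single pos_shape, the permute/flatten chain is collapsed to one line per axis, and rotary_pos_emb_full is forced to float32 before indexing. A small worked example of the (h, w) position-id ordering this produces for a 4x4 grid with spatial_merge_size = 2, illustrative only:

import torch

# The reshape/permute/flatten groups positions 2x2 block by 2x2 block,
# matching the spatial-merge window order used by the merger.
h, w, s = 4, 4, 2
pos_shape = (h // s, s, w // s, s)
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = hpos.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
wpos_ids = wpos.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
print(torch.stack([hpos_ids, wpos_ids], dim=-1)[:4].tolist())
# [[0, 0], [0, 1], [1, 0], [1, 1]]  -> the first 2x2 merge window
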
@@ -360,14 +285,22 @@ def get_window_index(self, grid_thw):

     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True)
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0, dtype=torch.int32
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        cu_seqlens = cu_seqlens.to("cuda", non_blocking=True)
+
         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
             device=hidden_states.device,
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+        max_window_seqlen = (cu_window_seqlens[1:] - cu_window_seqlens[:-1]).max().item()

         seq_len, _ = hidden_states.size()
         hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
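
cu_seqlens, max_seqlen, and max_window_seqlen are now computed once in forward and passed down, rather than being re-derived inside every attention layer. A small worked example of what the prefix sums look like for two single-frame images, illustrative only:

import torch
import torch.nn.functional as F

# grid_thw holds (t, h, w) patch counts per image: here a 16x16 and an 8x8 grid.
grid_thw = torch.tensor([[1, 16, 16], [1, 8, 8]])
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
print(cu_seqlens.tolist(), max_seqlen)  # [0, 256, 320] 256
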
@@ -376,30 +309,20 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to("cuda", dtype=self.get_dtype(), non_blocking=True)
-        position_embeddings = (emb.cos(), emb.sin())
-
-        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
-            dim=0,
-            # Select dtype based on the following factors:
-            # - FA2 requires that cu_seqlens_q must have dtype int32
-            # - torch.onnx.export requires that cu_seqlens_q must have same
-            #   dtype as grid_thw
-            # See https://github.com/huggingface/transformers/pull/34852
-            # for more information
-            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
-        )
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
                 cu_seqlens_now = cu_seqlens
+                max_seqlen_now = max_seqlen
             else:
                 cu_seqlens_now = cu_window_seqlens
+                max_seqlen_now = max_window_seqlen
+
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
-                position_embeddings=position_embeddings,
+                max_seqlen=max_seqlen_now,
+                rotary_pos_emb=rotary_pos_emb,
             )

         hidden_states = self.merger(hidden_states)
@@ -408,12 +331,23 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.

         return hidden_states

-    def load_model(self, weight_dir):
+    def load_image(self, img: List[ImageItem]):
+        pixel_values = None
+        if isinstance(img, ImageItem):
+            image_data = read_shm(get_shm_name_data(img.uuid))
+            image_data = Image.open(BytesIO(image_data))
+            image_data = resize_image(image_data)
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+        elif isinstance(img, dict):
+            image_data = read_shm(get_shm_name_data(img["uuid"]))
+            image_data = Image.open(BytesIO(image_data))
+            image_data = resize_image(image_data)
+            pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+        else:
+            raise Exception("Unsupport input types: {} for {}".format(type(img), img))
+        return pixel_values.to(dtype=self.data_type), image_grid_thw

-        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
-        with open(processor_config_path, "r") as f:
-            processor_config_dict = json.load(f)
-        self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+    def load_model(self, weight_dir):

         bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
         if bin_weight_files:
@@ -466,7 +400,7 @@ def encode(self, images: List[ImageItem]):
         imgs = torch.cat(img_tensors, dim=0)
         grid_thw = torch.cat(img_grids, dim=0)

-        pixel_values = imgs.to("cuda", dtype=self.get_dtype(), non_blocking=True)
+        pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
         image_grid_thw = grid_thw.to("cuda", non_blocking=True)

         all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw)
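
With get_dtype() removed, the vision tower's dtype comes from kvargs["data_type"] and is normalized once by _init_datatype, then reused by load_image and encode. The string-to-dtype mapping it implements is, in effect:

import torch

# Equivalent mapping to _init_datatype (strings accepted in this diff):
_DTYPE_MAP = {
    "fp16": torch.float16, "float16": torch.float16,
    "bf16": torch.bfloat16, "bfloat16": torch.bfloat16,
    "fp32": torch.float32, "float32": torch.float32,
}
print(_DTYPE_MAP["bfloat16"])  # torch.bfloat16
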
