
Commit 97e60fb

0808-opti-qwen2-vl-vit-2

1 parent 9a95141 commit 97e60fb

File tree: 1 file changed, +133 -69 lines changed

1 file changed

+133
-69
lines changed

lightllm/models/qwen2_5_vl/qwen2_5_visual.py

Lines changed: 133 additions & 69 deletions
@@ -24,12 +24,48 @@
 from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
 from lightllm.models.vit.triton_kernel.flashattention_nopad import flash_attention_fwd
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
-from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
 
 # adapted from
 # https://github.com/huggingface/transformers/blob/
 # be37d34f44ff1bc928e59ffb8a30adecab8835a8/src
 # /transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py#L30C1-L31C1
+class Qwen2_5_VLVisionConfig(PretrainedConfig):
+    model_type = "qwen2_5_vl"
+
+    def __init__(
+        self,
+        depth=32,
+        hidden_size=3584,
+        hidden_act="silu",
+        intermediate_size=3420,
+        num_heads=16,
+        in_channels=3,
+        patch_size=14,
+        spatial_merge_size=2,
+        temporal_patch_size=2,
+        tokens_per_second=4,
+        window_size=112,
+        out_hidden_size=3584,
+        fullatt_block_indexes=[7, 15, 23, 31],
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.depth = depth
+        self.hidden_size = hidden_size
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.tokens_per_second = tokens_per_second
+        self.window_size = window_size
+        self.fullatt_block_indexes = fullatt_block_indexes
+        self.out_hidden_size = out_hidden_size
+
+
 class Qwen2RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -68,6 +104,27 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed
+
+
 class Qwen2_5_VLVisionFlashAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
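
For reference, a minimal sketch of how the new module-level rotary helpers can be exercised on dummy tensors. The import path follows the file shown in this diff and is otherwise an assumption; shapes and values are illustrative, not part of the commit.

import torch

# Hedged sketch: exercise the pure-PyTorch rotary helper added above.
from lightllm.models.qwen2_5_vl.qwen2_5_visual import apply_rotary_pos_emb_vision

seq_len, num_heads, head_dim = 16, 4, 64

# rot_pos_emb() yields head_dim // 2 angles per token; the forward pass
# duplicates them to the full head_dim before taking cos/sin.
rotary_pos_emb = torch.randn(seq_len, head_dim // 2)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos, sin = emb.cos(), emb.sin()

q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
k = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
assert q_rot.shape == q.shape and q_rot.dtype == q.dtype
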
@@ -76,27 +133,26 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.proj = nn.Linear(dim, dim)
 
-    def apply_rotary_pos_emb_vision(self, t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-        t_ = t.float()
-        cos = freqs.cos()
-        sin = freqs.sin()
-        output = apply_rotary_emb(t_, cos, sin).type_as(t)
-        return output
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = self.apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-        k = self.apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+        if position_embeddings is None:
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
 
+        cu_seqlens = cu_seqlens.to(q.device, torch.int32)
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
-
         flash_attention_fwd(q, k, v, attn_output, cu_seqlens, max_seqlen)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
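
For context, a small sketch of the varlen metadata that this forward now derives internally instead of receiving max_seqlen as an argument. Segment lengths are illustrative; only torch is assumed.

import torch
import torch.nn.functional as F

# Hedged sketch: three packed segments of 4, 9 and 6 tokens.
seq_lens = torch.tensor([4, 9, 6])
cu_seqlens = F.pad(seq_lens.cumsum(0), (1, 0), value=0).to(torch.int32)
# tensor([ 0,  4, 13, 19], dtype=torch.int32)

# Longest segment, recovered per call exactly as in the forward above.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # 9
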
@@ -127,14 +183,12 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
-        max_seqlen: int = 0,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
         )
@@ -178,7 +232,6 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
-        self.data_type = kvargs.get("data_type", "bfloat16")
 
         self.depth = depth
         self.hidden_size = hidden_size
@@ -204,7 +257,7 @@ def __init__(
         )
 
         head_dim = self.hidden_size // self.num_heads
-        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).to("cuda", non_blocking=True)
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2).to("cuda", dtype=self.get_dtype(), non_blocking=True)
 
         self.blocks = nn.ModuleList(
             [
@@ -226,62 +279,41 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-        self._init_datatype()
-
-    def _init_datatype(self):
-        if isinstance(self.data_type, torch.dtype):
-            return
-        if self.data_type in ["fp16", "float16"]:
-            self.data_type = torch.float16
-        elif self.data_type in ["bf16", "bfloat16"]:
-            self.data_type = torch.bfloat16
-        elif self.data_type in ["fp32", "float32"]:
-            self.data_type = torch.float32
-        else:
-            raise ValueError(f"Unsupport datatype {self.data_type}!")
-        return
+        self.device = self.get_device()
+        self.dtype = self.get_dtype()
 
-    def load_model(self, weight_dir):
+    def get_dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.down_proj.weight.dtype
 
-        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
-        with open(processor_config_path, "r") as f:
-            processor_config_dict = json.load(f)
-        self.processor = Qwen2VLImageProcessor(**processor_config_dict)
-
-        bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
-        if bin_weight_files:
-            weight_dict = {}
-            for file_ in bin_weight_files:
-                f = torch.load(os.path.join(weight_dir, file_), "cpu")
-                for k, v in f.items():
-                    if "visual" in k:
-                        weight_dict[k[len("visual.") :]] = v
-
-        else:
-            hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
-            weight_dict = {}
-            for file_ in hf_weight_files:
-                f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
-                for k in f.keys():
-                    if "visual" in k:
-                        weight_dict[k[len("visual.") :]] = f.get_tensor(k)
-
-        self.load_state_dict(weight_dict)
+    def get_device(self) -> torch.device:
+        return self.blocks[0].mlp.down_proj.weight.device
 
     def rot_pos_emb(self, grid_thw):
         pos_ids = []
-        s = self.spatial_merge_size
-        for _, h, w in grid_thw:
-            pos_shape = (h // s, s, w // s, s)
+        for t, h, w in grid_thw:
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
-            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
-            hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
-            wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
 
-            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size).type(torch.float32)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
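
A hedged sketch of the index pattern the rewritten rot_pos_emb builds for a single (t=1, h=4, w=4) grid with spatial_merge_size=2; plain torch only, values illustrative.

import torch

# Hedged sketch: mirror the hpos/wpos interleaving of rot_pos_emb() above.
t, h, w, merge = 1, 4, 4, 2

hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos = hpos.reshape(h // merge, merge, w // merge, merge).permute(0, 2, 1, 3).flatten()

wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos = wpos.reshape(h // merge, merge, w // merge, merge).permute(0, 2, 1, 3).flatten()

pos_ids = torch.stack([hpos, wpos], dim=-1).repeat(t, 1)
# Rows come out grouped by 2x2 merge windows:
# (0,0) (0,1) (1,0) (1,1) | (0,2) (0,3) (1,2) (1,3) | (2,0) (2,1) (3,0) (3,1) | ...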

@@ -328,7 +360,7 @@ def get_window_index(self, grid_thw):
 
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw).to("cuda", non_blocking=True)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
         cu_window_seqlens = torch.tensor(
             cu_window_seqlens,
@@ -344,14 +376,20 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1).to("cuda", dtype=self.get_dtype(), non_blocking=True)
+        position_embeddings = (emb.cos(), emb.sin())
 
         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
+            # Select dtype based on the following factors:
+            #  - FA2 requires that cu_seqlens_q must have dtype int32
+            #  - torch.onnx.export requires that cu_seqlens_q must have same
+            #    dtype as grid_thw
+            # See https://github.com/huggingface/transformers/pull/34852
+            # for more information
             dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         for layer_num, blk in enumerate(self.blocks):
             if layer_num in self.fullatt_block_indexes:
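
A hedged sketch of the full-attention cu_seqlens built above from grid_thw, for two images with (t, h, w) grids of (1, 8, 8) and (1, 4, 6). Values are illustrative; the int32 choice follows the comment in the diff, with grid_thw.dtype used only while tracing for ONNX export.

import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 8, 8], [1, 4, 6]])

# Each image contributes t segments of h * w patch tokens.
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
# tensor([ 0, 64, 88], dtype=torch.int32)
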
@@ -361,8 +399,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
             hidden_states = blk(
                 hidden_states,
                 cu_seqlens=cu_seqlens_now,
-                max_seqlen=max_seqlen,
-                position_embeddings=emb,
+                position_embeddings=position_embeddings,
             )
 
         hidden_states = self.merger(hidden_states)
@@ -371,6 +408,33 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
 
         return hidden_states
 
+    def load_model(self, weight_dir):
+
+        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+        with open(processor_config_path, "r") as f:
+            processor_config_dict = json.load(f)
+        self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+
+        bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+        if bin_weight_files:
+            weight_dict = {}
+            for file_ in bin_weight_files:
+                f = torch.load(os.path.join(weight_dir, file_), "cpu")
+                for k, v in f.items():
+                    if "visual" in k:
+                        weight_dict[k[len("visual.") :]] = v
+
+        else:
+            hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+            weight_dict = {}
+            for file_ in hf_weight_files:
+                f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+                for k in f.keys():
+                    if "visual" in k:
+                        weight_dict[k[len("visual.") :]] = f.get_tensor(k)
+
+        self.load_state_dict(weight_dict)
+
     def encode(self, images: List[ImageItem]):
         img_tensors = []
         valid_ids = []
@@ -402,7 +466,7 @@ def encode(self, images: List[ImageItem]):
         imgs = torch.cat(img_tensors, dim=0)
         grid_thw = torch.cat(img_grids, dim=0)
 
-        pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+        pixel_values = imgs.to("cuda", dtype=self.get_dtype(), non_blocking=True)
         image_grid_thw = grid_thw.to("cuda", non_blocking=True)
 
         all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw)
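
A hedged sketch of the dtype handling that replaces the removed self.data_type field: the working precision now follows whatever dtype the loaded weights carry, as get_dtype() does by inspecting a block weight. The module and shapes below are illustrative stand-ins.

import torch
import torch.nn as nn

# Hedged sketch: infer the working dtype from a loaded weight instead of a
# configured string.
proj = nn.Linear(8, 8).to(torch.bfloat16)
dtype = proj.weight.dtype              # torch.bfloat16

imgs = torch.randn(2, 8)               # illustrative input batch
pixel_values = imgs.to(dtype=dtype)    # mirrors imgs.to("cuda", dtype=self.get_dtype(), ...)
assert pixel_values.dtype == dtype
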

0 commit comments
