
Commit bea6d68

qwen2vl-fix
Author: sangchengmeng (committed)
1 parent d45574e · commit bea6d68

File tree

4 files changed: +151 -6 lines changed

lightllm/models/llama/model.py
lightllm/models/qwen2_vl/infer_struct.py
lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py
lightllm/models/qwen2_vl/model.py


lightllm/models/llama/model.py

Lines changed: 49 additions & 0 deletions
@@ -88,6 +88,11 @@ def _init_custom(self):
             and self.config.get("rope_scaling", {}).get("rope_type", "base") == "llama3"
         ):
             self._init_to_get_llama3_rotary()
+        elif (
+            self.config.get("rope_scaling", None) is not None
+            and self.config.get("rope_scaling", {}).get("type", "base") == "mrope"
+        ):
+            self._init_to_get_mrope_rotary()
         else:
             self._init_to_get_rotary()
         return

@@ -332,3 +337,47 @@ def _init_to_get_llama3_rotary(self, default_base=10000):
         self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
         self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
         return
+
+    def _init_to_get_mrope_rotary(self, default_base=10000):
+        partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
+        if self.config.get("rope_scaling", {}) is None:
+            rope_scaling_factor = 1.0
+        else:
+            rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+
+        base = self.config.get("rope_theta", float(default_base))
+
+        if "max_sequence_length" in self.config:
+            max_seq_len = self.config["max_sequence_length"]
+        else:
+            max_position_embeddings = self.config.get(
+                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
+            )
+            max_seq_len = max_position_embeddings * rope_scaling_factor
+
+        # NTK
+        try:
+            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+            assert ntk_alpha >= 1
+            if ntk_alpha > 1:
+                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+            max_seq_len *= ntk_alpha
+            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
+        except:
+            pass
+
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+        )
+
+        t = (
+            torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
+            / rope_scaling_factor
+        )
+        freqs = torch.outer(t, inv_freq).unsqueeze(0).expand(3, -1, -1)
+        freqs = torch.cat((freqs, freqs), dim=-1)
+
+        self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+        self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
+
+        return
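The only structural change from the existing 1-D rotary caches is the leading axis: freqs is expanded to shape (3, positions, partial_head_dim), one plane per M-RoPE position axis (temporal, height, width in Qwen2-VL), so _cos_cached / _sin_cached can later be indexed per axis. A minimal standalone sketch of the same construction (the function name and default values below are illustrative, not lightllm API):

import torch

def build_mrope_cache(head_dim=128, base=1000000.0, max_positions=4096):
    # standard RoPE inverse frequencies over half the head dimension
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    t = torch.arange(max_positions, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)              # (positions, head_dim // 2)
    freqs = freqs.unsqueeze(0).expand(3, -1, -1)  # identical plane per axis: temporal, height, width
    freqs = torch.cat((freqs, freqs), dim=-1)     # (3, positions, head_dim)
    return torch.cos(freqs), torch.sin(freqs)

cos, sin = build_mrope_cache()
print(cos.shape)  # torch.Size([3, 4096, 128])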
lightllm/models/qwen2_vl/infer_struct.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+import torch
+import numpy as np
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+
+
+class Qwen2VLInferStateInfo(LlamaInferStateInfo):
+    def __init__(self):
+        super().__init__()
+        self.position_cos = None
+        self.position_sin = None
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        if self.is_prefill:
+            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
+            self.max_seq_len = b_seq_len_numpy.max()
+            position_ids = torch.from_numpy(
+                np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))])
+            ).cuda()
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+            position_ids = None
+        else:
+            position_ids = self.b_seq_len - 1
+            self.position_sin = model._sin_cached[:, position_ids, :].unsqueeze(1)
+            self.position_cos = model._cos_cached[:, position_ids, :].unsqueeze(1)
+        return
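Both branches index the (3, max_positions, head_dim) caches built by _init_to_get_mrope_rotary with a flat tensor of per-token positions, so position_cos / position_sin come out as (3, 1, num_tokens, head_dim), ready for the per-axis merge in apply_multimodal_rotary_pos_emb. A quick shape check (random stand-in values, not lightllm code):

import torch

head_dim, max_positions = 128, 4096
cos_cached = torch.randn(3, max_positions, head_dim)   # stand-in for model._cos_cached

# prefill with two requests of length 3 and 5: one position id per token
position_ids = torch.cat([torch.arange(3), torch.arange(5)])
position_cos = cos_cached[:, position_ids, :].unsqueeze(1)
print(position_cos.shape)  # torch.Size([3, 1, 8, 128])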
lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
+import torch
+import torch.functional as F
+import torch.distributed as dist
+import numpy as np
+
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+
+import torch.nn as nn
+from functools import partial
+
+
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
+    mrope_section = mrope_section * 2
+    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
+    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    return q_embed, k_embed
+
+
+class Qwen2RMSNorm(nn.Module):
+    def __init__(self, hidden_size, device, eps=1e-6):
+        super().__init__()
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states, weight):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        return (weight * hidden_states).to(input_dtype)
+
+
+class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):
+    def __init__(self, layer_num, network_config, mode=[]):
+        super().__init__(layer_num, network_config, mode)
+        self.mrope_section = network_config["rope_scaling"]["mrope_section"]
+        self.norm_fwd = Qwen2RMSNorm(
+            network_config["hidden_size"], device="cuda", eps=network_config.get("rms_norm_eps", 1e-06)
+        )
+
+    def _bind_norm(self):
+        self._ffn_norm = partial(LlamaTransformerLayerInfer._ffn_norm, self)
+
+    def _att_norm(self, input_embedding, infer_state, layer_weight) -> torch.Tensor:
+        return self.norm_fwd(input_embedding, weight=layer_weight.att_norm_weight_.weight)
+
+    def _get_qkv(self, input, cache_kv, infer_state, layer_weight):
+        q = layer_weight.q_proj.mm(input)
+        cache_kv = layer_weight.kv_proj.mm(
+            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        seq_len, _ = q.shape
+        q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
+        k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
+        new_q, new_k = apply_multimodal_rotary_pos_emb(
+            q, k, infer_state.position_cos, infer_state.position_sin, self.mrope_section
+        )
+        new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1)
+        cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)
+
+        return new_q, cache_kv
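apply_multimodal_rotary_pos_emb merges the three cos/sin planes back into a single (1, 1, seq_len, head_dim) tensor: mrope_section says how many rotary dimensions each position axis (temporal, height, width) owns, and doubling the list keeps the split aligned with the concatenated cos/sin halves. A small shape walk-through, assuming the [16, 24, 24] sections used by Qwen2-VL configs (illustrative values, not read from this repo):

import torch

mrope_section = [16, 24, 24]                 # temporal / height / width shares of half the head dim
head_dim, seq_len = 128, 10

cos = torch.randn(3, 1, seq_len, head_dim)   # stand-in for infer_state.position_cos
sections = mrope_section * 2                 # [16, 24, 24, 16, 24, 24] covers the full head_dim
chunks = cos.split(sections, dim=-1)         # each chunk: (3, 1, seq_len, section)
merged = torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1).unsqueeze(1)
print(merged.shape)  # torch.Size([1, 1, 10, 128]); broadcasts against q of shape (1, heads, seq_len, head_dim)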

lightllm/models/qwen2_vl/model.py

Lines changed: 5 additions & 6 deletions
@@ -12,17 +12,16 @@
 from typing import List, Optional, Union
 from transformers.utils import TensorType, logging
 from lightllm.common.build_utils import repair_config
+from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer
 
-# from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor
 import torch
 from PIL import Image
 from .vision_process import smart_resize
 from lightllm.models.qwen2.layer_weights import transformer_layer_weight, pre_and_post_layer_weight
 from lightllm.models.qwen2.model import Qwen2TpPartModel
 import os
 
-# from lightllm.models.qwen2_vl.layer_weight.pre_and_post_layer_weight import Qwen2VLPreAndPostLayerWeight
-
 # Wrap of the original tokenizer
 class QWen2VLTokenizer:
     def __init__(self, tokenizer=None, image_processor=None, **kwargs):

@@ -89,10 +88,10 @@ def __getattr__(self, name):
 
 class Qwen2VLTpPartModel(Qwen2TpPartModel):
 
-    # weight class
-    # pre_and_post_weight_class = Qwen2VLPreAndPostLayerWeight
-    # infer class
     pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+    transformer_layer_infer_class = Qwen2VLTransformerLayerInfer
+
+    infer_state_class = Qwen2VLInferStateInfo
 
     def __init__(self, kvargs):
         super().__init__(kvargs)