
Commit fa0cb52

SangChengC, sangchengmeng, shihaobai, wangzaijun, and hiworldwzj authored
Add qwen3 vl (#1095)
Co-authored-by: sangchengmeng <sangchengmeng@sensetime.com>
Co-authored-by: shihaobai <1798930569@qq.com>
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>
Co-authored-by: wangzaijun <wzjhelloworld@qq.com>
1 parent ef28098 commit fa0cb52


57 files changed: +2195 −532 lines changed

lightllm/common/basemodel/infer_struct.py

Lines changed: 0 additions & 21 deletions
@@ -122,27 +122,6 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
                 attr_.copy_(attr_value, non_blocking=True)
         return
 
-    def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
-        """
-        Helper that marks which multimodal objects' tokens actually need to take part in the computation during chunked prefill.
-        Because the prompt is split into chunks, not every multimodal object's tokens are part of the current computation.
-        """
-        multi_objs = []
-        for _, p in enumerate(self.multimodal_params):
-            for obj in p["images"] + p["audios"]:
-                multi_objs.append(obj)
-
-        if multi_objs:
-            obj_start_ids = torch.tensor([e["token_id"] for e in multi_objs], dtype=torch.int64, device="cuda")
-            obj_token_lens = torch.tensor([e["token_num"] for e in multi_objs], dtype=torch.int64, device="cuda")
-            marks = mark_multimodal_obj(
-                obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
-            )
-            marks_array = marks.detach().cpu().numpy()
-            for mark, obj in zip(marks_array, multi_objs):
-                obj["_prefill_"] = mark > 0
-        return
-
     def prefill_dp_balance(self, input_ids: torch.Tensor):
         """
         During prefill in dp mode, readjust and redistribute the input data so that the amount of data handled by each dp rank does not become too imbalanced, which would otherwise lead to

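For context, here is a rough pure-PyTorch rendering of what the removed helper used to compute. The marking semantics (an object participates in the current chunk if any of its placeholder token ids fall inside the chunk's input_ids) are inferred from the call site rather than taken from the mark_multimodal_obj kernel itself, and the name mark_objs_reference is illustrative only:

    import torch

    def mark_objs_reference(obj_start_ids: torch.Tensor,
                            obj_token_lens: torch.Tensor,
                            input_ids: torch.Tensor) -> torch.Tensor:
        # An object counts as "in this prefill chunk" if any token id in
        # [start, start + token_num) shows up in the chunk's input_ids.
        marks = []
        for start, length in zip(obj_start_ids.tolist(), obj_token_lens.tolist()):
            hit = ((input_ids >= start) & (input_ids < start + length)).any()
            marks.append(int(hit.item()))
        return torch.tensor(marks, dtype=torch.int64, device=input_ids.device)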
lightllm/common/basemodel/triton_kernel/multimodal_emb.py

Lines changed: 36 additions & 26 deletions
@@ -7,20 +7,22 @@
 def _fwd_kernel(
     Prompt_ids,
     Text_weight_embs,
-    Img_embs,
+    Embed_cache,
     Out,
     Img_token_lens,
     Img_start_token_ids,
-    Img_start_locs,
+    Img_start_locs_in_cache,
     stride_text_emb_s,
     stride_text_emb_d,  # text_stride
-    stride_img_emb_s,
-    stride_img_emb_d,  # img_stride
+    stride_emb_cache_s,
+    stride_emb_cache_l,
+    stride_emb_cache_d,  # img_stride
     stride_out_s,
     stride_out_d,
     tp_text_start_token_id,
     tp_text_end_token_id,
     hidden_size,
+    tp_world_size,
     BLOCK_HIDDEN_DIM: tl.constexpr,
 ):

@@ -44,7 +46,7 @@ def _fwd_kernel(
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
 
     img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
-    img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
+    img_start_loc = tl.load(Img_start_locs_in_cache + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     # load store img emb
     for _ in range(

@@ -57,11 +59,16 @@ def _fwd_kernel(
         1,
     ):
         load_emb = tl.load(
-            Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d,
+            Embed_cache
+            + stride_emb_cache_s.to(tl.int64) * (img_start_loc + token_id - img_start_token_id)
+            + stride_emb_cache_l * 0
+            + stride_emb_cache_d * off_d,
             mask=off_d < hidden_size,
             other=0,
         )
-        tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
+        tl.store(
+            Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb / tp_world_size, mask=off_d < hidden_size
+        )
     return

@@ -70,35 +77,38 @@ def multimodal_emb(
     out: torch.Tensor,
     prompt_ids: torch.Tensor,
     text_weight_embs: torch.Tensor,
-    img_embs: torch.Tensor,
+    embed_cache: torch.Tensor,
     img_token_lens: torch.Tensor,
     img_start_token_ids: torch.Tensor,
-    img_start_locs: torch.Tensor,
-    tp_text_start_token_id,
-    tp_text_end_token_id,
+    img_start_locs_in_cache: torch.Tensor,
+    tp_text_start_token_id: int,
+    tp_text_end_token_id: int,
+    tp_world_size: int,
 ):
     total_len = prompt_ids.shape[0]
     BLOCK = triton.next_power_of_2(out.shape[1])
     # print(len(img_token_lens))
     grid = (total_len, len(img_token_lens) + 1)
     num_warps = 1
     _fwd_kernel[grid](
-        prompt_ids,
-        text_weight_embs,
-        img_embs,
-        out,
-        img_token_lens,
-        img_start_token_ids,
-        img_start_locs,
-        text_weight_embs.stride(0),
-        text_weight_embs.stride(1),
-        img_embs.stride(0),
-        img_embs.stride(1),
-        out.stride(0),
-        out.stride(1),
-        tp_text_start_token_id,
-        tp_text_end_token_id,
+        Prompt_ids=prompt_ids,
+        Text_weight_embs=text_weight_embs,
+        Embed_cache=embed_cache,
+        Out=out,
+        Img_token_lens=img_token_lens,
+        Img_start_token_ids=img_start_token_ids,
+        Img_start_locs_in_cache=img_start_locs_in_cache,
+        stride_text_emb_s=text_weight_embs.stride(0),
+        stride_text_emb_d=text_weight_embs.stride(1),
+        stride_emb_cache_s=embed_cache.stride(0),
+        stride_emb_cache_l=embed_cache.stride(1),
+        stride_emb_cache_d=embed_cache.stride(2),
+        stride_out_s=out.stride(0),
+        stride_out_d=out.stride(1),
+        tp_text_start_token_id=tp_text_start_token_id,
+        tp_text_end_token_id=tp_text_end_token_id,
         hidden_size=out.shape[1],
+        tp_world_size=float(tp_world_size),
         BLOCK_HIDDEN_DIM=BLOCK,
         num_warps=num_warps,
         num_stages=1,

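With this change the kernel no longer receives a per-request img_embs tensor; it reads image rows straight out of a shared 3-D embed cache (indexed roughly as [cache_slot, level, hidden], with level 0 used here) and pre-divides them by tp_world_size so that the all-reduce performed after the embedding layer sums back to the original values. A minimal sketch of the assumed per-token semantics, not the Triton implementation (the helper name and the text-token branch are illustrative):

    import torch

    def multimodal_emb_reference(out, prompt_ids, text_weight_embs, embed_cache,
                                 img_token_lens, img_start_token_ids,
                                 img_start_locs_in_cache,
                                 tp_text_start_token_id, tp_text_end_token_id,
                                 tp_world_size: float):
        for seq_index, token_id in enumerate(prompt_ids.tolist()):
            if tp_text_start_token_id <= token_id < tp_text_end_token_id:
                # Text token: copy this rank's slice of the word-embedding table.
                out[seq_index] = text_weight_embs[token_id - tp_text_start_token_id]
                continue
            for start_id, length, cache_loc in zip(img_start_token_ids.tolist(),
                                                   img_token_lens.tolist(),
                                                   img_start_locs_in_cache.tolist()):
                if start_id <= token_id < start_id + length:
                    # Image token: read its row from the shared embed cache and
                    # pre-divide, since every tp rank fills the same rows and the
                    # later all-reduce sums the tp_world_size copies.
                    row = embed_cache[cache_loc + token_id - start_id, 0]
                    out[seq_index] = row.to(out.dtype) / tp_world_size
                    break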
lightllm/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,8 @@
 from lightllm.models.internvl.model import InternVLInternlm2TpPartModel
 from lightllm.models.qwen2_vl.model import Qwen2VLTpPartModel
 from lightllm.models.qwen2_reward.model import Qwen2RewardTpPartModel
+from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel
+from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel
 from lightllm.models.gemma3.model import Gemma3TpPartModel
 from lightllm.models.tarsier2.model import (
     Tarsier2Qwen2TpPartModel,

lightllm/models/gemma3/layer_infer/pre_layer_infer.py

Lines changed: 30 additions & 31 deletions
@@ -2,7 +2,6 @@
 from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
 from lightllm.distributed.communication_op import all_reduce
 from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
-from lightllm.server.embed_cache.utils import bytes2tensor, get_shm_name_embed, read_shm
 
 
 class Gemma3PreLayerInfer(LlamaMultimodalPreLayerInfer):

@@ -14,16 +13,15 @@ def __init__(self, network_config, mode):
         return
 
     def context_forward(self, input_ids, infer_state, layer_weight):
-        img_weight = []
         img_start_token_ids = []
         img_token_lens = []
-        img_start_loc = 0
-        img_start_locs = []
+        img_start_locs_in_cache = []
         device = layer_weight.wte_weight_.device
         dtype = layer_weight.wte_weight_.dtype
         hidden_size = layer_weight.wte_weight_.shape[1]
         weight_mask = torch.zeros((len(input_ids)), dtype=torch.float32, device=device)
 
+        # TODO
         scale = self.embed_scale
         for idx, input_id in enumerate(input_ids):
             if input_id == self.boi_token_index:

@@ -35,45 +33,46 @@ def context_forward(self, input_ids, infer_state, layer_weight):
             else:
                 weight_mask[idx] = scale
 
-        infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
-
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"]:
                 # skip the same image
-                if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
+                if img["token_id"] in img_start_token_ids:
                     continue
-                # pull the img_embeds by uid from shm
-                data = read_shm(get_shm_name_embed(img["uuid"]))
-                img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
                 img_start_token_ids.append(img["token_id"])
                 img_token_lens.append(img["token_num"])
-                img_start_locs.append(img_start_loc)
-                img_start_loc += img["token_num"]
+                img_start_locs_in_cache.append(img["start_index_in_embed_cache"])
         out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
-        if len(img_weight) > 0:
-            img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
-        else:
-            img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype)
-        assert img_weight.shape[1] == hidden_size, (
+
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
+
+        assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
             f"Dimension mismatch: text weight dimension is {hidden_size}, "
-            f"but image weight dimension is {img_weight.shape[1]}"
+            f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}"
         )
         # each tp will fill the img embeds, should divide by world_size
-        img_weight = img_weight / self.tp_world_size_
-        img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
-        img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
-        img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)
+        img_start_token_ids = torch.tensor(img_start_token_ids, dtype=torch.long, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        img_token_lens = torch.tensor(img_token_lens, dtype=torch.long, device="cpu", pin_memory=True).cuda(
+            non_blocking=True
+        )
+        img_start_locs_in_cache = torch.tensor(
+            img_start_locs_in_cache, dtype=torch.long, device="cpu", pin_memory=True
+        ).cuda(non_blocking=True)
 
         multimodal_emb(
-            out,
-            input_ids,
-            layer_weight.wte_weight_,
-            img_weight,
-            img_token_lens,
-            img_start_token_ids,
-            img_start_locs,
-            self.vob_start_id_,
-            self.vob_end_id_,
+            out=out,
+            prompt_ids=input_ids,
+            text_weight_embs=layer_weight.wte_weight_,
+            embed_cache=cpu_embed_cache_tensor,
+            img_token_lens=img_token_lens,
+            img_start_token_ids=img_start_token_ids,
+            img_start_locs_in_cache=img_start_locs_in_cache,
+            tp_text_start_token_id=self.vob_start_id_,
+            tp_text_end_token_id=self.vob_end_id_,
+            tp_world_size=self.tp_world_size_,
        )
         input_dtype = out.dtype
         if self.tp_world_size_ > 1:

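The gemma3 pre-layer now looks image embeddings up in a process-wide CPU embed cache (via g_infer_context) instead of re-reading them from shared memory per request, and it builds its small index tensors in pinned host memory so the .cuda(non_blocking=True) copies can stay asynchronous. A minimal standalone sketch of that transfer pattern, with made-up values (the helper to_cuda_async is illustrative, not part of the repository):

    import torch

    def to_cuda_async(values, dtype=torch.long):
        # non_blocking=True only overlaps with other work when the source tensor
        # lives in pinned (page-locked) host memory; hence pin_memory=True.
        cpu_t = torch.tensor(values, dtype=dtype, device="cpu", pin_memory=True)
        return cpu_t.cuda(non_blocking=True)

    # Example index tensors; the values are made up.
    img_token_lens = to_cuda_async([64, 256])
    img_start_locs_in_cache = to_cuda_async([0, 64])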
lightllm/models/qwen2_vl/infer_struct.py

Lines changed: 3 additions & 3 deletions
@@ -33,8 +33,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         self.position_ids = position_ids.unsqueeze(0).expand(3, -1)
 
         self.position_ids = self.position_ids.contiguous()
-        self.position_cos = model._cos_cached[self.position_ids]  # (3, L, D)
-        self.position_sin = model._sin_cached[self.position_ids]  # (3, L, D)
+        self.position_cos = model._cos_cached[self.position_ids]
+        self.position_sin = model._sin_cached[self.position_ids]
         if get_env_start_args().enable_fa3:
             self.max_seq_len = self.max_kv_seq_len
             self.q_max_seq_len = self.max_q_seq_len

@@ -66,7 +66,7 @@ def get_mrope_position(self, multimodal_params: List[dict]) -> torch.Tensor:
         b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True)  # image_num x 4
         b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
         b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
-        b_image_len = torch.tensor(b_image_len, device=self.position_ids.device)
+        b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)
         position_ids = self.position_ids.unsqueeze(0).expand(3, -1).contiguous()
         get_mrope_position_triton(
             b_image_start_idx=b_image_start_idx,

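In this struct, position_ids carries one row per rotary axis (temporal, height, width), so indexing the cached cos/sin tables with it yields one slice per axis; the (3, L, D) comment was dropped from the code above, but the shapes still follow from plain advanced indexing. A small self-contained check with made-up sizes (cos_cached stands in for model._cos_cached):

    import torch

    max_position, rot_dim, seq_len = 4096, 128, 10
    cos_cached = torch.randn(max_position, rot_dim)  # stand-in for model._cos_cached
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(3, -1).contiguous()  # (3, L)

    position_cos = cos_cached[position_ids]  # advanced indexing -> (3, L, rot_dim)
    assert position_cos.shape == (3, seq_len, rot_dim)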
lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py

Lines changed: 12 additions & 15 deletions
@@ -5,31 +5,28 @@
 from typing import Tuple
 from functools import partial
 
-from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
+from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 
 
 class Qwen2VLTransformerLayerInfer(LlamaTransformerLayerInfer):
     def __init__(self, layer_num, network_config, mode=[]):
         super().__init__(layer_num, network_config, mode)
-        self.mrope_section = network_config["rope_scaling"]["mrope_section"]
-        axis_map = []
-        for i, n in enumerate(self.mrope_section * 2):
-            axis_map += [i % 3] * n
-        self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
+        mrope_section = network_config["rope_scaling"]["mrope_section"]
+        self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda")
 
     def _get_qkv(self, input, infer_state, layer_weight):
         q = layer_weight.q_proj.mm(input)
         cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
-        seq_len, _ = q.shape
-        q = q.view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
-        self.axis_map = self.axis_map.to(q.device)
-        k = cache_kv[:, : self.tp_k_head_num_, :].view(1, seq_len, -1, self.head_dim_).transpose(1, 2)
-        new_q, new_k = mrope_triton(q, k, infer_state.position_cos, infer_state.position_sin, self.axis_map)
-        new_q = new_q.transpose(1, 2).reshape(1, seq_len, -1)
-        cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)
-
-        return new_q, cache_kv
+        mrope_triton_fused(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_),
+            cache_kv[:, : self.tp_k_head_num_, :],
+            infer_state.position_cos,
+            infer_state.position_sin,
+            self.mrope_section,
+            is_interleaved=False,
+        )
+        return q, cache_kv
 
     def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
         # TODO

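The rewrite above drops the precomputed axis_map and instead hands mrope_section directly to mrope_triton_fused, which updates q and the k slice of cache_kv in place. For intuition, the mapping the old code built can be reproduced as below, a sketch using Qwen2-VL's typical mrope_section of [16, 24, 24] and assuming the fused kernel derives the same per-dimension axis assignment internally:

    import torch

    mrope_section = [16, 24, 24]  # temporal / height / width split of half the head dim
    axis_map = []
    for i, n in enumerate(mrope_section * 2):  # doubled: cos/sin cover both rotary halves
        axis_map += [i % 3] * n                # dim j rotates with axis 0 (t), 1 (h) or 2 (w)
    axis_map = torch.tensor(axis_map, dtype=torch.int32)

    assert axis_map.numel() == 2 * sum(mrope_section)  # 128, the head dim for this section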
lightllm/models/qwen2_vl/triton_kernel/get_mrope_position_ids.py

Lines changed: 28 additions & 1 deletion
@@ -138,7 +138,34 @@ def test():
         b_q_seq_len,
         b_start_loc,
     )
-    print(position_ids)
+
+    # print(position_ids)
+    old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)
+
+    position_ids = (
+        torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda")
+        .unsqueeze(0)
+        .expand(3, -1)
+        .contiguous()
+    )
+    b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda")
+    b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda")
+    b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda")
+
+    get_mrope_position_triton(
+        b_image_start_idx,
+        b_image_thwd,
+        b_image_nums,
+        b_image_start_num,
+        b_image_len,
+        position_ids,
+        b_ready_cache_len,
+        b_q_seq_len,
+        b_start_loc,
+    )
+
+    assert torch.equal(old_value, position_ids)
+
     """
     tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8],
             [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8],
