Commit 02486eb

author sangchengmeng committed 1210
1 parent 0da89eb commit 02486eb

6 files changed: +32 additions, -65 deletions


lightllm/models/qwen3_vl/infer_struct.py

Lines changed: 5 additions & 2 deletions
@@ -7,9 +7,12 @@
 class Qwen3VLInferStateInfo(LlamaInferStateInfo):
     def __init__(self):
         super().__init__()
+        self.input_ids = None
         self.deepstack_features = []
-        self.img_first_token_locs = []
-        self.img_last_token_locs = []
+        self.deepstack_end_layer = None
+        self.img_start_token_ids = []
+        self.img_token_lens = []
+        self.img_start_locs = []

     def apply_interleaved_mrope(self, freqs, mrope_section):
         """Apply interleaved MRoPE to 3D rotary embeddings.

lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py

Lines changed: 15 additions & 19 deletions
@@ -3,6 +3,7 @@

 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo

 from lightllm.server.embed_cache.utils import (
     bytes2tensor,
@@ -20,13 +21,14 @@ def __init__(self, network_config, mode):
         super().__init__(network_config, mode)
         return

-    def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
-
+    def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
         img_weight = []
-        img_start_token_ids = []
-        img_token_lens = []
         img_start_loc = 0
-        img_start_locs = []
+
+        infer_state.input_ids = input_ids
+        infer_state.img_start_token_ids = []
+        infer_state.img_token_lens = []
+        infer_state.img_start_locs = []

         device = layer_weight.wte_weight_.device
         dtype = layer_weight.wte_weight_.dtype
@@ -37,12 +39,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
                 # skip the same image
-                if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
+                if img["token_id"] in infer_state.img_start_token_ids or img["_prefill_"] is False:
                     continue
-                pos = (input_ids == img["token_id"]).nonzero(as_tuple=True)
-                if pos[0].numel() == 0:
-                    continue
-                # pull the img_embeds by uid from shm
+
                 all_img_embed_df = bytes2tensor(read_shm(get_shm_name_embed(img["uuid"])))
                 per_image_deepstack = []

@@ -55,12 +54,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
                     per_image_deepstack.append(all_img_embed_df[start:end])

                 infer_state.deepstack_features.append(per_image_deepstack)
-                img_insert_locs = int(pos[0][0])
-                infer_state.img_first_token_locs.append(img_insert_locs)
-                infer_state.img_last_token_locs.append(img_insert_locs + img["token_num"])
-                img_start_token_ids.append(img["token_id"])
-                img_token_lens.append(img["token_num"])
-                img_start_locs.append(img_start_loc)
+                infer_state.img_start_token_ids.append(img["token_id"])
+                infer_state.img_token_lens.append(img["token_num"])
+                infer_state.img_start_locs.append(img_start_loc)
                 img_start_loc += img["token_num"]
         out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)

@@ -74,9 +70,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         )
         # each tp will fill the img embeds, should divide by world_size
         img_weight = img_weight / self.tp_world_size_
-        img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
-        img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
-        img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)
+        img_start_token_ids = torch.Tensor(infer_state.img_start_token_ids).to(device=device, dtype=torch.long)
+        img_token_lens = torch.Tensor(infer_state.img_token_lens).to(device=device, dtype=torch.long)
+        img_start_locs = torch.Tensor(infer_state.img_start_locs).to(device=device, dtype=torch.long)

         multimodal_emb(
             out,
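For context on how these per-image tensors are consumed: in lightllm each image's placeholder tokens occupy a contiguous token-id range starting at img["token_id"], which is why the position lookup via the removed pos = (input_ids == ...) code is no longer needed. A rough pure-PyTorch sketch of the image part of the embedding fill; the real multimodal_emb is a Triton kernel whose full argument list is not shown in this hunk, and the range-based lookup below is an assumption:

# hypothetical reference for the image part of the fill, not the Triton kernel
for i in range(len(img_start_token_ids)):
    tok0 = img_start_token_ids[i]   # first token id of image i
    n = img_token_lens[i]           # number of image tokens
    src0 = img_start_locs[i]        # row offset of image i inside img_weight
    # assumed: image i owns the contiguous token-id range [tok0, tok0 + n)
    mask = (input_ids >= tok0) & (input_ids < tok0 + n)
    out[mask] = img_weight[src0 + (input_ids[mask] - tok0)]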

lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py

Lines changed: 6 additions & 20 deletions
@@ -16,6 +16,7 @@
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 from lightllm.distributed import all_reduce
 from lightllm.utils.dist_utils import get_global_world_size
+from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features


 class Qwen3VLTransformerLayerInfer(Qwen3TransformerLayerInfer):
@@ -46,24 +47,9 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
         if self.tp_world_size_ > 1:
             all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
-        if infer_state.deepstack_features:
-            for i in range(len(infer_state.img_first_token_locs)):
-                start = infer_state.img_first_token_locs[i]
-                end = infer_state.img_last_token_locs[i]
-                deepstack_features = infer_state.deepstack_features[i]
-                if end <= input_embdings.shape[0] and self.layer_num_ in range(len(deepstack_features)):
-                    deepstack_features_cur_layer = deepstack_features[self.layer_num_].to(
-                        device=input_embdings.device, non_blocking=True
-                    )
-                    print(
-                        f"self.layer_num_ is {self.layer_num_}, i is{i} ,"
-                        f"deepstack_features_cur_layer is {deepstack_features_cur_layer}"
-                    )
-                    input_embdings[
-                        start:end,
-                    ].add_(deepstack_features_cur_layer)
-                if self.layer_num_ == len(deepstack_features):
-                    infer_state.img_first_token_locs = []
-                    infer_state.img_last_token_locs = []
-                    infer_state.deepstack_features = []
+        apply_deepstack_features(
+            input_embeddings=input_embdings,
+            infer_state=infer_state,
+            layer_num=self.layer_num_,
+        )
         return input_embdings
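The inline per-image loop removed above is now handled by apply_deepstack_features, imported from lightllm/models/qwen3_vl/triton_kernel/deepstack_multimodal_emb.py; that file is not among this commit's visible hunks. A minimal pure-PyTorch sketch of the behaviour it presumably replaces, inferred from the deleted loop and the new infer-state fields (the token-id-range masking is an assumption, and apply_deepstack_features_reference is a hypothetical name):

def apply_deepstack_features_reference(input_embeddings, infer_state, layer_num):
    # hypothetical reference, not the actual Triton kernel
    if not infer_state.deepstack_features:
        return
    for i, per_image in enumerate(infer_state.deepstack_features):
        if layer_num >= len(per_image):
            continue  # deepstack features only cover the first few decoder layers
        tok0 = infer_state.img_start_token_ids[i]
        n = infer_state.img_token_lens[i]
        # assumed: image i owns the contiguous token-id range [tok0, tok0 + n)
        mask = (infer_state.input_ids >= tok0) & (infer_state.input_ids < tok0 + n)
        feats = per_image[layer_num].to(device=input_embeddings.device, non_blocking=True)
        input_embeddings[mask] += feats

Moving this into a dedicated kernel also drops the per-layer debug print that the old loop carried.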

lightllm/models/qwen3_vl/model.py

Lines changed: 0 additions & 1 deletion
@@ -62,7 +62,6 @@ def get_image_token_length(self, img: ImageItem):
         )
         grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
         token_num = (grid_h * grid_w) // (self.merge_size ** 2)
-        print(f"token_num is {token_num}")
         return token_num

     def get_audio_token_length(self, audio: AudioItem):
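A quick numeric check of the token count above, with illustrative values (patch_size=16 and merge_size=2 are assumed here; read the real values from the model config):

# illustrative numbers only
patch_size, merge_size = 16, 2
resized_height, resized_width = 768, 1024
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size  # 48, 64
token_num = (grid_h * grid_w) // (merge_size ** 2)  # 3072 // 4 = 768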

lightllm/models/qwen3_vl/qwen3_visual.py

Lines changed: 0 additions & 7 deletions
@@ -68,13 +68,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(
             -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
         )
-        # num_patches = hidden_states.shape[0]
-        # print(f"num_patches is {num_patches}")
-        # torch.cuda.synchronize()
-        # time0 = time.perf_counter()
         hidden_states = self.proj(hidden_states).view(-1, self.embed_dim)
-        # torch.cuda.synchronize()
-        # print(f"patch embed time is {time.perf_counter()-time0}")
         return hidden_states


@@ -385,7 +379,6 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
                     hidden_states
                 )
                 deepstack_feature_lists.append(deepstack_feature)
-                # print(f"ds time is {time.perf_counter()-time0}")

         hidden_states = self.merger(hidden_states)

lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py

Lines changed: 6 additions & 16 deletions
@@ -16,6 +16,7 @@
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 from lightllm.distributed import all_reduce
 from lightllm.utils.dist_utils import get_global_world_size
+from lightllm.models.qwen3_vl.triton_kernel.deepstack_multimodal_emb import apply_deepstack_features


 class Qwen3VLMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
@@ -75,20 +76,9 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
         if self.tp_world_size_ > 1:
             all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
-        if infer_state.deepstack_features:
-            for i in range(len(infer_state.img_first_token_locs)):
-                start = infer_state.img_first_token_locs[i]
-                end = infer_state.img_last_token_locs[i]
-                deepstack_features = infer_state.deepstack_features[i]
-                if end <= input_embdings.shape[0] and self.layer_num_ in range(len(deepstack_features)):
-                    deepstack_features_cur_layer = deepstack_features[self.layer_num_].to(
-                        device=input_embdings.device, non_blocking=True
-                    )
-                    input_embdings[
-                        start:end,
-                    ].add_(deepstack_features_cur_layer)
-                if self.layer_num_ == len(deepstack_features):
-                    infer_state.img_first_token_locs = []
-                    infer_state.img_last_token_locs = []
-                    infer_state.deepstack_features = []
+        apply_deepstack_features(
+            input_embeddings=input_embdings,
+            infer_state=infer_state,
+            layer_num=self.layer_num_,
+        )
         return input_embdings
