Commit 87cd0fa

qwen3-vl cpu embed

1 parent: 0b7ca92
2 files changed (+19, -21 lines)

lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py

Lines changed: 9 additions & 8 deletions
```diff
@@ -34,19 +34,20 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w
         from lightllm.server.router.model_infer.infer_batch import g_infer_context
 
         cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
+        infer_state.cpu_embed_cache_tensor = cpu_embed_cache_tensor
 
         assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
             f"Dimension mismatch: text weight dimension is {hidden_size}, "
             f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}"
         )
         # each tp will fill the img embeds, should divide by world_size
-        img_start_token_ids = torch.tensor(img_start_token_ids, dtype=torch.long, device="cpu", pin_memory=True).cuda(
-            non_blocking=True
-        )
-        img_token_lens = torch.tensor(img_token_lens, dtype=torch.long, device="cpu", pin_memory=True).cuda(
+        infer_state.img_start_token_ids = torch.tensor(
+            img_start_token_ids, dtype=torch.long, device="cpu", pin_memory=True
+        ).cuda(non_blocking=True)
+        infer_state.img_token_lens = torch.tensor(img_token_lens, dtype=torch.long, device="cpu", pin_memory=True).cuda(
             non_blocking=True
         )
-        img_start_locs_in_cache = torch.tensor(
+        infer_state.img_start_locs_in_cache = torch.tensor(
             img_start_locs_in_cache, dtype=torch.long, device="cpu", pin_memory=True
         ).cuda(non_blocking=True)
 
@@ -55,9 +56,9 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w
             prompt_ids=input_ids,
             text_weight_embs=layer_weight.wte_weight_,
             embed_cache=cpu_embed_cache_tensor,
-            img_token_lens=img_token_lens,
-            img_start_token_ids=img_start_token_ids,
-            img_start_locs_in_cache=img_start_locs_in_cache,
+            img_token_lens=infer_state.img_token_lens,
+            img_start_token_ids=infer_state.img_start_token_ids,
+            img_start_locs_in_cache=infer_state.img_start_locs_in_cache,
             tp_text_start_token_id=self.vob_start_id_,
             tp_text_end_token_id=self.vob_end_id_,
             tp_world_size=self.tp_world_size_,
```
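The hunk above stages the image index lists in pinned host memory, copies them to the GPU asynchronously, and stashes the resulting tensors on `infer_state` so that `apply_deepstack_features` (second file below) can reuse them at deeper layers instead of rebuilding them. Here is a minimal sketch of that transfer idiom; the helper name and values are illustrative, and only the `pin_memory` + `non_blocking` combination is taken from the diff:

```python
import torch

def to_gpu_async(values):
    # pin_memory=True allocates page-locked host memory, which is what lets
    # .cuda(non_blocking=True) issue an async DMA copy instead of blocking
    # the Python thread until the transfer finishes.
    host = torch.tensor(values, dtype=torch.long, device="cpu", pin_memory=True)
    return host.cuda(non_blocking=True)

# Hypothetical inputs: two images starting at prompt positions 3 and 120.
img_start_token_ids = to_gpu_async([3, 120])
img_token_lens = to_gpu_async([64, 64])
```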

lightllm/models/qwen3_vl/triton_kernel/deepstack_multimodal_emb.py

Lines changed: 10 additions & 13 deletions
```diff
@@ -11,7 +11,7 @@ def _deepstack_add_kernel(
     Out,
     Img_token_lens,
     Img_start_token_ids,
-    Img_start_locs,
+    Img_start_locs_in_cache,
     stride_deep_s,
     stride_deep_d,
     stride_out_s,
@@ -26,8 +26,8 @@ def _deepstack_add_kernel(
     off_d = tl.arange(0, BLOCK_DIM)
 
     img_start_token_id = tl.load(Img_start_token_ids + img_handle_id)
-    img_start_loc = tl.load(Img_start_locs + img_handle_id)
     img_token_len = tl.load(Img_token_lens + img_handle_id)
+    img_start_loc_in_cache = tl.load(Img_start_locs_in_cache + img_handle_id)
 
     # check whether the current token belongs to this image
     cond = (token_id >= img_start_token_id) & (token_id < img_start_token_id + img_token_len)
@@ -36,7 +36,7 @@ def _deepstack_add_kernel(
         token_offset = token_id - img_start_token_id
 
         deep_row = tl.load(
-            Deepstack_embs + stride_deep_s * (img_start_loc + token_offset) + off_d,
+            Deepstack_embs + stride_deep_s * (img_start_loc_in_cache + token_offset) + off_d,
             mask=off_d < hidden_size,
             other=0,
         )
@@ -60,7 +60,7 @@ def add_deepstack_embs(
     deepstack_embs: torch.Tensor,
     img_token_lens: torch.Tensor,
     img_start_token_ids: torch.Tensor,
-    img_start_locs: torch.Tensor,
+    img_start_locs_in_cache: torch.Tensor,
 ):
     assert input_ids.dim() == 1
     assert out.dim() == 2
@@ -79,7 +79,7 @@ def add_deepstack_embs(
         out,
         img_token_lens,
         img_start_token_ids,
-        img_start_locs,
+        img_start_locs_in_cache,
         deepstack_embs.stride(0),
         deepstack_embs.stride(1),
         out.stride(0),
@@ -105,20 +105,17 @@ def apply_deepstack_features(
     if not infer_state.deepstack_features:
         return
 
-    if layer_num >= len(infer_state.deepstack_features[0]):
-        return
+    deepstack_num_layers = infer_state.cpu_embed_cache_tensor.shape[1] - 1
 
-    per_img_deepstack_features = [
-        infer_state.deepstack_features[i][layer_num] for i in range(infer_state.img_token_lens.shape[0])
-    ]
-    all_deepstack_features = torch.cat(per_img_deepstack_features, dim=0)
+    if layer_num >= deepstack_num_layers:
+        return
 
     add_deepstack_embs(
         out=input_embeddings,
         input_ids=infer_state.input_ids,
-        deepstack_embs=all_deepstack_features,
+        deepstack_embs=infer_state.cpu_embed_cache_tensor[:, layer_num + 1, :],
         img_token_lens=infer_state.img_token_lens,
         img_start_token_ids=infer_state.img_start_token_ids,
-        img_start_locs=infer_state.img_start_locs,
+        img_start_locs_in_cache=infer_state.img_start_locs_in_cache,
     )
     return
```
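Read together, `deepstack_num_layers = shape[1] - 1` and the `[:, layer_num + 1, :]` slice imply a cache layout where slot 0 of the second dimension holds the main image embedding consumed in `context_forward`, and slots 1 onward hold the per-layer deepstack features. Below is a rough pure-PyTorch sketch of what the fused kernel computes under that assumed layout; the shapes, values, and the `add_deepstack_embs_ref` helper are illustrative, not from the repo:

```python
import torch

num_cache_rows, deepstack_layers, hidden = 512, 3, 128
# Slot 0: main image embedding; slots 1..deepstack_layers: deepstack features.
cpu_embed_cache = torch.randn(num_cache_rows, 1 + deepstack_layers, hidden)

input_embeddings = torch.zeros(16, hidden)
img_start_token_ids = [2]        # image begins at prompt position 2
img_token_lens = [8]             # image spans 8 tokens
img_start_locs_in_cache = [100]  # its rows begin at cache row 100

def add_deepstack_embs_ref(out: torch.Tensor, layer_num: int) -> None:
    num_layers = cpu_embed_cache.shape[1] - 1
    if layer_num >= num_layers:  # mirrors the early return in the diff
        return
    deep = cpu_embed_cache[:, layer_num + 1, :]
    # For each image, add its deepstack rows onto its token positions in
    # `out` -- the per-token work that _deepstack_add_kernel fuses into
    # one launch over a (num_tokens, num_images) grid.
    for start_tok, length, start_loc in zip(
        img_start_token_ids, img_token_lens, img_start_locs_in_cache
    ):
        out[start_tok : start_tok + length] += deep[start_loc : start_loc + length]

add_deepstack_embs_ref(input_embeddings, layer_num=0)
```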
