
Commit 0da89eb

Author: sangchengmeng
Commit message: 1210
1 parent 3ee963e commit 0da89eb

File tree: 8 files changed, +148 -86 lines changed

Lines changed: 87 additions & 1 deletion
@@ -1,8 +1,94 @@
+import torch
+import torch.distributed as dist
+
+from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+
+from lightllm.server.embed_cache.utils import (
+    bytes2tensor,
+    read_shm,
+    get_shm_name_embed,
+)
+from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
+from lightllm.distributed.communication_op import all_reduce
 from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer


 class Qwen3VLMultimodalPreLayerInfer(LlamaMultimodalPreLayerInfer):
     def __init__(self, network_config, mode):
         super().__init__(network_config, mode)
-        self.use_deepstack = True
         return
+
+    def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
+
+        img_weight = []
+        img_start_token_ids = []
+        img_token_lens = []
+        img_start_loc = 0
+        img_start_locs = []
+
+        device = layer_weight.wte_weight_.device
+        dtype = layer_weight.wte_weight_.dtype
+        hidden_size = layer_weight.wte_weight_.shape[1]
+
+        infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
+
+        for batch_id, p in enumerate(infer_state.multimodal_params):
+            for img in p["images"] + p["audios"]:
+                # skip the same image
+                if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
+                    continue
+                pos = (input_ids == img["token_id"]).nonzero(as_tuple=True)
+                if pos[0].numel() == 0:
+                    continue
+                # pull the img_embeds by uid from shm
+                all_img_embed_df = bytes2tensor(read_shm(get_shm_name_embed(img["uuid"])))
+                per_image_deepstack = []
+
+                deepstack_layer_num = all_img_embed_df.shape[0] // img["token_num"] - 1
+                img_weight.append(all_img_embed_df[: img["token_num"]].cuda())
+
+                for layer in range(deepstack_layer_num):
+                    start = img["token_num"] * (layer + 1)
+                    end = img["token_num"] * (layer + 2)
+                    per_image_deepstack.append(all_img_embed_df[start:end])
+
+                infer_state.deepstack_features.append(per_image_deepstack)
+                img_insert_locs = int(pos[0][0])
+                infer_state.img_first_token_locs.append(img_insert_locs)
+                infer_state.img_last_token_locs.append(img_insert_locs + img["token_num"])
+                img_start_token_ids.append(img["token_id"])
+                img_token_lens.append(img["token_num"])
+                img_start_locs.append(img_start_loc)
+                img_start_loc += img["token_num"]
+        out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
+
+        if len(img_weight) > 0:
+            img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
+        else:
+            img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype)
+        assert img_weight.shape[1] == hidden_size, (
+            f"Dimension mismatch: text weight dimension is {hidden_size}, "
+            f"but image weight dimension is {img_weight.shape[1]}"
+        )
+        # each tp will fill the img embeds, should divide by world_size
+        img_weight = img_weight / self.tp_world_size_
+        img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
+        img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
+        img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)
+
+        multimodal_emb(
+            out,
+            input_ids,
+            layer_weight.wte_weight_,
+            img_weight,
+            img_token_lens,
+            img_start_token_ids,
+            img_start_locs,
+            self.vob_start_id_,
+            self.vob_end_id_,
+        )
+        if self.tp_world_size_ > 1:
+            all_reduce(out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False)
+        return out
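
The prefill path above now pulls a single packed tensor per image out of shared memory: the first token_num rows are the image's token embedding, and each following token_num-row block holds one deepstack layer's features. A minimal sketch of that layout and of the slicing used in context_forward (toy shapes and names, not taken from the repo):

    import torch

    # Toy packed tensor: token embedding followed by one block per deepstack layer.
    token_num, hidden, deepstack_layers = 4, 8, 3
    packed = torch.randn(token_num * (1 + deepstack_layers), hidden)

    # First token_num rows: the embedding injected at the image's token positions.
    img_embed = packed[:token_num]

    # Remaining rows, one token_num-sized block per layer, mirroring the loop above.
    layers = packed.shape[0] // token_num - 1  # == deepstack_layers
    deepstack = [packed[token_num * (i + 1): token_num * (i + 2)] for i in range(layers)]

    assert img_embed.shape == (token_num, hidden)
    assert all(d.shape == (token_num, hidden) for d in deepstack)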

lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py

Lines changed: 8 additions & 3 deletions
@@ -55,10 +55,15 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
             deepstack_features_cur_layer = deepstack_features[self.layer_num_].to(
                 device=input_embdings.device, non_blocking=True
             )
+            print(
+                f"self.layer_num_ is {self.layer_num_}, i is{i} ,"
+                f"deepstack_features_cur_layer is {deepstack_features_cur_layer}"
+            )
             input_embdings[
                 start:end,
             ].add_(deepstack_features_cur_layer)
-        infer_state.img_first_token_locs = []
-        infer_state.img_last_token_locs = []
-        infer_state.deepstack_features = []
+        if self.layer_num_ == len(deepstack_features):
+            infer_state.img_first_token_locs = []
+            infer_state.img_last_token_locs = []
+            infer_state.deepstack_features = []
         return input_embdings
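
The state reset is now deferred: img_first_token_locs, img_last_token_locs and deepstack_features are only cleared once self.layer_num_ reaches the number of deepstack layers, so every injection layer still sees the features. A toy illustration of the guard, assuming layer indices are 0-based as in the diff:

    # Hypothetical 3 deepstack layers; the clear fires on the first layer past the last one.
    deepstack_features = ["feat_l0", "feat_l1", "feat_l2"]
    for layer_num in range(5):
        if layer_num < len(deepstack_features):
            pass  # inject deepstack_features[layer_num] into the hidden states here
        if layer_num == len(deepstack_features):
            deepstack_features = []  # cleared exactly once, at layer 3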

lightllm/models/qwen3_vl/model.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def get_image_token_length(self, img: ImageItem):
         )
         grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
         token_num = (grid_h * grid_w) // (self.merge_size ** 2)
+        print(f"token_num is {token_num}")
         return token_num

     def get_audio_token_length(self, audio: AudioItem):
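
For reference, a worked example of the token count printed above; the patch_size and merge_size values here are illustrative assumptions, not read from the model config:

    patch_size, merge_size = 16, 2
    resized_height, resized_width = 448, 672
    grid_h, grid_w = resized_height // patch_size, resized_width // patch_size  # 28, 42
    token_num = (grid_h * grid_w) // (merge_size ** 2)  # 1176 // 4 = 294
    print(f"token_num is {token_num}")  # token_num is 294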

lightllm/models/qwen3_vl/qwen3_visual.py

Lines changed: 46 additions & 7 deletions
@@ -15,6 +15,7 @@
 
 import os
 import json
+import time
 from PIL import Image
 from io import BytesIO
 from typing import List
@@ -67,7 +68,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(
             -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
         )
+        # num_patches = hidden_states.shape[0]
+        # print(f"num_patches is {num_patches}")
+        # torch.cuda.synchronize()
+        # time0 = time.perf_counter()
         hidden_states = self.proj(hidden_states).view(-1, self.embed_dim)
+        # torch.cuda.synchronize()
+        # print(f"patch embed time is {time.perf_counter()-time0}")
         return hidden_states


@@ -194,6 +201,39 @@ def _init_datatype(self):
             raise ValueError(f"Unsupport datatype {self.data_type}!")
         return

+    def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature_lists, valid_ids):
+        # input: image_embed: [img_embed1, img_embed2, img_embed3]
+        #        deepstack_feature_lists: [df1-1, df1-2, df1-3,
+        #                                  df2-1, df2-2, df2-3,
+        #                                  df3-1, df3-2, df3-3]
+        #        valid_ids: [[start_1, end_1], [start_2, end_2], [start_3, end_3]]
+        #
+        # return: all_img_embeds_ds: [img_embed1, df1-1, df1-2, df1-3,
+        #                             img_embed2, df2-1, df2-2, df2-3,
+        #                             img_embed3, df3-1, df3-2, df3-3]
+        #         valid_ids: [[start_1, end_1], [start_2, end_2], [start_3, end_3]]  # start and end of each image_embed
+        all_chunks = []
+        new_valid_ids = []
+
+        row_offset = 0
+
+        for start, end in valid_ids:
+            hs_i = image_embed[start:end]
+            ds_i_list = [feat[start:end] for feat in deepstack_feature_lists]
+
+            combined_i = torch.cat([hs_i, *ds_i_list], dim=0)
+
+            new_start = row_offset
+            new_end = row_offset + combined_i.size(0)
+            new_valid_ids.append([new_start, new_end])
+
+            all_chunks.append(combined_i)
+
+            row_offset += new_end
+
+        all_img_embeds_ds = torch.cat(all_chunks, dim=0)
+        return all_img_embeds_ds, new_valid_ids
+
     def load_model(self, weight_dir):

         processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
@@ -320,21 +360,17 @@ def fast_pos_embed_interpolate(self, grid_thw):

     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
-
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         hidden_states = hidden_states + pos_embeds
-
         rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw)
         rotary_cos = rotary_cos.to("cuda", non_blocking=True)
         rotary_sin = rotary_sin.to("cuda", non_blocking=True)
-
         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
             dtype=torch.int32,
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to("cuda", non_blocking=True)
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-
         deepstack_feature_lists = []
         for layer_num, blk in enumerate(self.blocks):
             hidden_states = blk(
@@ -349,6 +385,7 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
                     hidden_states
                 )
                 deepstack_feature_lists.append(deepstack_feature)
+        # print(f"ds time is {time.perf_counter()-time0}")

         hidden_states = self.merger(hidden_states)

@@ -391,7 +428,9 @@ def encode(self, images: List[ImageItem]):

         pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
         image_grid_thw = grid_thw.to("cuda", non_blocking=True)
+        img_embeds, deepstack_feature_lists = self.forward(pixel_values, grid_thw=image_grid_thw)
+        all_img_embeds_df, valid_ids = self.concat_img_embed_and_deepstack_features(
+            img_embeds, deepstack_feature_lists, valid_ids
+        )

-        all_img_embeds, deepstack_feature_lists = self.forward(pixel_values, grid_thw=image_grid_thw)
-
-        return all_img_embeds, uuids, valid_ids, deepstack_feature_lists
+        return all_img_embeds_df, uuids, valid_ids
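
encode now packs each image's embedding together with its per-layer deepstack features into one tensor before it is written out, which is why the separate "-deepstack" shared-memory entry is dropped further down. A single-image sketch of the packing performed by concat_img_embed_and_deepstack_features (toy shapes; the real tensors come from the ViT forward above):

    import torch

    token_num, hidden, num_deepstack = 4, 8, 3
    img_embeds = torch.randn(token_num, hidden)
    deepstack_feature_lists = [torch.randn(token_num, hidden) for _ in range(num_deepstack)]

    start, end = 0, token_num  # this image's slice in valid_ids
    combined = torch.cat(
        [img_embeds[start:end]] + [f[start:end] for f in deepstack_feature_lists], dim=0
    )

    # (1 + num_deepstack) * token_num rows: embedding first, then one block per
    # deepstack layer -- exactly the layout the qwen3_vl prefill layer unpacks again.
    assert combined.shape == ((1 + num_deepstack) * token_num, hidden)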

lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py

Lines changed: 4 additions & 3 deletions
@@ -87,7 +87,8 @@ def context_forward(self, input_embdings, infer_state: Qwen3VLInferStateInfo, la
             input_embdings[
                 start:end,
             ].add_(deepstack_features_cur_layer)
-        infer_state.img_first_token_locs = []
-        infer_state.img_last_token_locs = []
-        infer_state.deepstack_features = []
+        if self.layer_num_ == len(deepstack_features):
+            infer_state.img_first_token_locs = []
+            infer_state.img_last_token_locs = []
+            infer_state.deepstack_features = []
         return input_embdings

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 1 addition & 17 deletions
@@ -6,13 +6,7 @@

 from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
 from lightllm.utils.infer_utils import mark_cost_time
-from lightllm.server.embed_cache.utils import (
-    bytes2tensor,
-    read_shm,
-    get_shm_name_embed,
-    get_shm_name_deepstack,
-    bytes2list,
-)
+from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed
 from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
 from lightllm.distributed.communication_op import all_reduce

@@ -35,7 +29,6 @@
 class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
     def __init__(self, network_config, mode):
         super().__init__(network_config, mode)
-        self.use_deepstack = False
         return

     def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
@@ -57,18 +50,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
                 # skip the same image
                 if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
                     continue
-                pos = (input_ids == img["token_id"]).nonzero(as_tuple=True)
-                if pos[0].numel() == 0:
-                    continue
                 # pull the img_embeds by uid from shm
                 data = read_shm(get_shm_name_embed(img["uuid"]))
                 img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
-                if self.use_deepstack:
-                    deepstack_features = read_shm(get_shm_name_deepstack(img["uuid"]))
-                    infer_state.deepstack_features.append(bytes2list(deepstack_features))
-                    img_insert_locs = int(pos[0][0])
-                    infer_state.img_first_token_locs.append(img_insert_locs)
-                    infer_state.img_last_token_locs.append(img_insert_locs + img["token_num"])
                 img_start_token_ids.append(img["token_id"])
                 img_token_lens.append(img["token_num"])
                 img_start_locs.append(img_start_loc)

lightllm/server/embed_cache/utils.py

Lines changed: 0 additions & 45 deletions
@@ -21,52 +21,11 @@ def tensor2bytes(t: torch.Tensor):
     return buf.read()


-def list2bytes(tensors: List[torch.Tensor]) -> bytes:
-    # detach().cpu() and copy each tensor one by one
-    safe_list = []
-    for t in tensors:
-        if t is None:
-            safe_list.append(None)
-            continue
-        t = t.detach().cpu()
-        if not t.is_contiguous():
-            t = t.contiguous()
-        dest = torch.empty_like(t)
-        dest.copy_(t)
-        safe_list.append(dest)
-    buf = BytesIO()
-    torch.save(safe_list, buf, _use_new_zipfile_serialization=False, pickle_protocol=4)
-    buf.seek(0)
-    return buf.read()
-
-
 def bytes2tensor(b):
     # return torch.from_numpy(np.frombuffer(b, dtype=np.float16)).cuda()
     return torch.load(BytesIO(b), weights_only=False)


-def bytes2list(b: bytes, device: Optional[torch.device] = None, non_blocking: bool = False) -> List[torch.Tensor]:
-    obj = torch.load(BytesIO(b), map_location="cpu", weights_only=False)
-
-    if isinstance(obj, tuple):
-        obj = list(obj)
-    if not isinstance(obj, list):
-        raise TypeError(f"Loaded object is {type(obj)}, expected list or tuple.")
-
-    if device is None:
-        return obj
-
-    out: List[torch.Tensor] = []
-    for x in obj:
-        if x is None:
-            out.append(None)
-        elif isinstance(x, torch.Tensor):
-            out.append(x.to(device, non_blocking=non_blocking))
-        else:
-            raise TypeError(f"List element is {type(x)}, expected Tensor or None.")
-    return out
-
-
 def create_shm(name, data):
     try:
         data_size = len(data)
@@ -95,7 +54,3 @@ def get_shm_name_data(uid):

 def get_shm_name_embed(uid):
     return str(uid) + "-embed"
-
-
-def get_shm_name_deepstack(uid):
-    return str(uid) + "-deepstack"
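
With the list helpers removed, the embed cache round-trips exactly one tensor per image. A sketch of that round-trip; only bytes2tensor is shown in full above, so the torch.save side here is an assumption chosen to match its torch.load counterpart:

    import io
    import torch

    def tensor2bytes_sketch(t: torch.Tensor) -> bytes:
        # Assumed serialization: torch.save into an in-memory buffer.
        buf = io.BytesIO()
        torch.save(t.detach().cpu(), buf)
        buf.seek(0)
        return buf.read()

    def bytes2tensor_sketch(b: bytes) -> torch.Tensor:
        # Mirrors the bytes2tensor shown in the diff.
        return torch.load(io.BytesIO(b), weights_only=False)

    packed = torch.randn(16, 8)
    assert torch.equal(packed, bytes2tensor_sketch(tensor2bytes_sketch(packed)))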

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 1 addition & 10 deletions
@@ -21,12 +21,8 @@
 from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
 from lightllm.server.embed_cache.utils import (
     tensor2bytes,
-    read_shm,
     create_shm,
-    get_shm_name_data,
     get_shm_name_embed,
-    get_shm_name_deepstack,
-    list2bytes,
 )
 from lightllm.utils.infer_utils import set_random_seed
 from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
@@ -111,8 +107,7 @@ def forward(self, images: List[ImageItem]):
     # @calculate_time(show=False, min_cost_ms=300)
     def exposed_encode(self, images: List[ImageItem]):
         images = obtain(images)
-        all_img_embeds, uuids, valid_ids, *deepstack_features = self.forward(images)
-        deepstack_feature_lists = deepstack_features[0] if deepstack_features else None
+        all_img_embeds, uuids, valid_ids = self.forward(images)
         all_img_embeds = all_img_embeds.to(torch.device("cpu"))

         if self.tp_rank_id == 0:
@@ -125,10 +120,6 @@ def exposed_encode(self, images: List[ImageItem]):
                 start, end = valid_ids[i]
                 cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
                 create_shm(get_shm_name_embed(uid), cur_embed_bytes)
-                if deepstack_feature_lists is not None:
-                    per_image_deepstack = [feat[start:end] for feat in deepstack_feature_lists]
-                    deepstack_features_bytes = list2bytes(per_image_deepstack)
-                    create_shm(get_shm_name_deepstack(uid), deepstack_features_bytes)
                 ids_to_set.append(uid)
             if ids_to_set:
                 self.cache_client.root.set_items_embed(ids_to_set)
