
Commit f4233b6: fix pre layer infer
Parent: 789149f

5 files changed: +79 additions, -107 deletions


lightllm/common/basemodel/infer_struct.py

Lines changed: 0 additions & 21 deletions
@@ -122,27 +122,6 @@ def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
                     attr_.copy_(attr_value, non_blocking=True)
         return
 
-    def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
-        """
-        Helper that marks, during chunked prefill, which multimodal objects have tokens that actually need to take part in the computation.
-        Because the prompt is split into chunks, not every multimodal object's tokens are involved.
-        """
-        multi_objs = []
-        for _, p in enumerate(self.multimodal_params):
-            for obj in p["images"] + p["audios"]:
-                multi_objs.append(obj)
-
-        if multi_objs:
-            obj_start_ids = torch.tensor([e["token_id"] for e in multi_objs], dtype=torch.int64, device="cuda")
-            obj_token_lens = torch.tensor([e["token_num"] for e in multi_objs], dtype=torch.int64, device="cuda")
-            marks = mark_multimodal_obj(
-                obj_start_token_ids=obj_start_ids, obj_token_lens=obj_token_lens, input_ids=input_ids
-            )
-            marks_array = marks.detach().cpu().numpy()
-            for mark, obj in zip(marks_array, multi_objs):
-                obj["_prefill_"] = mark > 0
-        return
-
     def prefill_dp_balance(self, input_ids: torch.Tensor):
         """
         During prefill in dp mode, re-adjust and redistribute the input data to reduce the imbalance in the amount of data handled by each dp rank, which otherwise causes
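The deleted helper wrapped the `mark_multimodal_obj` Triton kernel to decide which images/audios actually intersect the current prefill chunk; with the embed-cache path introduced in this commit that pre-filtering is no longer needed. For readers following the change, here is a minimal CPU-side sketch of the check it performed, assuming (as the kernel call suggests, though the diff does not spell it out) that each object's placeholder tokens occupy the contiguous id range `[token_id, token_id + token_num)`:

```python
import torch

def mark_objs_for_prefill_reference(multimodal_params, input_ids: torch.Tensor) -> None:
    """Rough CPU equivalent of the removed helper: set obj["_prefill_"] to True
    when any of the object's placeholder token ids appear in this chunk's input_ids."""
    objs = [obj for p in multimodal_params for obj in p["images"] + p["audios"]]
    if not objs:
        return
    starts = torch.tensor([o["token_id"] for o in objs], dtype=torch.int64)
    lens = torch.tensor([o["token_num"] for o in objs], dtype=torch.int64)
    ids = input_ids.to("cpu", torch.int64)
    # object i is "hit" if some input id falls inside [starts[i], starts[i] + lens[i])
    hit = ((ids[None, :] >= starts[:, None]) & (ids[None, :] < (starts + lens)[:, None])).any(dim=1)
    for obj, mark in zip(objs, hit.tolist()):
        obj["_prefill_"] = mark
```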

lightllm/common/basemodel/triton_kernel/multimodal_emb.py

Lines changed: 36 additions & 26 deletions
@@ -7,20 +7,22 @@
 def _fwd_kernel(
     Prompt_ids,
     Text_weight_embs,
-    Img_embs,
+    Embed_cache,
     Out,
     Img_token_lens,
     Img_start_token_ids,
-    Img_start_locs,
+    Img_start_locs_in_cache,
     stride_text_emb_s,
     stride_text_emb_d,  # text_stride
-    stride_img_emb_s,
-    stride_img_emb_d,  # img_stride
+    stride_emb_cache_s,
+    stride_emb_cache_l,
+    stride_emb_cache_d,  # img_stride
     stride_out_s,
     stride_out_d,
     tp_text_start_token_id,
     tp_text_end_token_id,
     hidden_size,
+    tp_world_size,
     BLOCK_HIDDEN_DIM: tl.constexpr,
 ):
 
@@ -44,7 +46,7 @@ def _fwd_kernel(
         tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
 
     img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
-    img_start_loc = tl.load(Img_start_locs + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
+    img_start_loc = tl.load(Img_start_locs_in_cache + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     img_token_len = tl.load(Img_token_lens + img_handle_id - 1, mask=img_handle_id >= 1, other=0)
     # load store img emb
     for _ in range(
@@ -57,11 +59,16 @@ def _fwd_kernel(
         1,
     ):
         load_emb = tl.load(
-            Img_embs + stride_img_emb_s * (img_start_loc + token_id - img_start_token_id) + off_d * stride_img_emb_d,
+            Embed_cache
+            + stride_emb_cache_s.to(tl.int64) * (img_start_loc + token_id - img_start_token_id)
+            + stride_emb_cache_l * 0
+            + stride_emb_cache_d * off_d,
             mask=off_d < hidden_size,
             other=0,
         )
-        tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size)
+        tl.store(
+            Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb / tp_world_size, mask=off_d < hidden_size
+        )
     return
 
 
@@ -70,35 +77,38 @@ def multimodal_emb(
     out: torch.Tensor,
     prompt_ids: torch.Tensor,
     text_weight_embs: torch.Tensor,
-    img_embs: torch.Tensor,
+    embed_cache: torch.Tensor,
     img_token_lens: torch.Tensor,
     img_start_token_ids: torch.Tensor,
-    img_start_locs: torch.Tensor,
-    tp_text_start_token_id,
-    tp_text_end_token_id,
+    img_start_locs_in_cache: torch.Tensor,
+    tp_text_start_token_id: int,
+    tp_text_end_token_id: int,
+    tp_world_size: int,
 ):
     total_len = prompt_ids.shape[0]
     BLOCK = triton.next_power_of_2(out.shape[1])
     # print(len(img_token_lens))
     grid = (total_len, len(img_token_lens) + 1)
     num_warps = 1
     _fwd_kernel[grid](
-        prompt_ids,
-        text_weight_embs,
-        img_embs,
-        out,
-        img_token_lens,
-        img_start_token_ids,
-        img_start_locs,
-        text_weight_embs.stride(0),
-        text_weight_embs.stride(1),
-        img_embs.stride(0),
-        img_embs.stride(1),
-        out.stride(0),
-        out.stride(1),
-        tp_text_start_token_id,
-        tp_text_end_token_id,
+        Prompt_ids=prompt_ids,
+        Text_weight_embs=text_weight_embs,
+        Embed_cache=embed_cache,
+        Out=out,
+        Img_token_lens=img_token_lens,
+        Img_start_token_ids=img_start_token_ids,
+        Img_start_locs_in_cache=img_start_locs_in_cache,
+        stride_text_emb_s=text_weight_embs.stride(0),
+        stride_text_emb_d=text_weight_embs.stride(1),
+        stride_emb_cache_s=embed_cache.stride(0),
+        stride_emb_cache_l=embed_cache.stride(1),
+        stride_emb_cache_d=embed_cache.stride(2),
+        stride_out_s=out.stride(0),
+        stride_out_d=out.stride(1),
+        tp_text_start_token_id=tp_text_start_token_id,
+        tp_text_end_token_id=tp_text_end_token_id,
         hidden_size=out.shape[1],
+        tp_world_size=float(tp_world_size),
        BLOCK_HIDDEN_DIM=BLOCK,
        num_warps=num_warps,
        num_stages=1,

lightllm/models/gemma3/layer_infer/pre_layer_infer.py

Lines changed: 22 additions & 28 deletions
@@ -14,16 +14,15 @@ def __init__(self, network_config, mode):
         return
 
     def context_forward(self, input_ids, infer_state, layer_weight):
-        img_weight = []
         img_start_token_ids = []
         img_token_lens = []
-        img_start_loc = 0
-        img_start_locs = []
+        img_start_locs_in_cache = []
         device = layer_weight.wte_weight_.device
         dtype = layer_weight.wte_weight_.dtype
         hidden_size = layer_weight.wte_weight_.shape[1]
         weight_mask = torch.zeros((len(input_ids)), dtype=torch.float32, device=device)
 
+        # TODO
         scale = self.embed_scale
         for idx, input_id in enumerate(input_ids):
             if input_id == self.boi_token_index:
@@ -35,45 +34,40 @@ def context_forward(self, input_ids, infer_state, layer_weight):
             else:
                 weight_mask[idx] = scale
 
-        infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
-
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"]:
                 # skip the same image
-                if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
+                if img["token_id"] in img_start_token_ids:
                     continue
-                # pull the img_embeds by uid from shm
-                data = read_shm(get_shm_name_embed(img["uuid"]))
-                img_weight.append(bytes2tensor(data).view(dtype).view(img["token_num"], -1).cuda(non_blocking=True))
                 img_start_token_ids.append(img["token_id"])
                 img_token_lens.append(img["token_num"])
-                img_start_locs.append(img_start_loc)
-                img_start_loc += img["token_num"]
+                img_start_locs_in_cache.append(img["start_index_in_embed_cache"])
         out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
-        if len(img_weight) > 0:
-            img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
-        else:
-            img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype)
-        assert img_weight.shape[1] == hidden_size, (
+
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
+
+        assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
             f"Dimension mismatch: text weight dimension is {hidden_size}, "
-            f"but image weight dimension is {img_weight.shape[1]}"
+            f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}"
         )
         # each tp will fill the img embeds, should divide by world_size
-        img_weight = img_weight / self.tp_world_size_
         img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
         img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
-        img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)
+        img_start_locs_in_cache = torch.Tensor(img_start_locs_in_cache).to(device=device, dtype=torch.long)
 
         multimodal_emb(
-            out,
-            input_ids,
-            layer_weight.wte_weight_,
-            img_weight,
-            img_token_lens,
-            img_start_token_ids,
-            img_start_locs,
-            self.vob_start_id_,
-            self.vob_end_id_,
+            out=out,
+            prompt_ids=input_ids,
+            text_weight_embs=layer_weight.wte_weight_,
+            embed_cache=cpu_embed_cache_tensor,
+            img_token_lens=img_token_lens,
+            img_start_token_ids=img_start_token_ids,
+            img_start_locs_in_cache=img_start_locs_in_cache,
+            tp_text_start_token_id=self.vob_start_id_,
+            tp_text_end_token_id=self.vob_end_id_,
+            tp_world_size=self.tp_world_size_,
         )
         input_dtype = out.dtype
         if self.tp_world_size_ > 1:
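The "divide by world_size" comment still applies, but the scaling now happens inside the kernel (`load_emb / tp_world_size`) instead of on a materialized `img_weight` tensor. As a reference for what each rank contributes per image token, a small plain-torch sketch (not the actual kernel; the `(slots, 1, hidden)` cache layout is inferred from the stride arguments above):

```python
import torch

def image_rows_per_rank(
    embed_cache: torch.Tensor,   # (cache_slots, 1, hidden_size)
    token_ids: torch.Tensor,     # placeholder ids of one image's tokens
    img_start_token_id: int,
    img_start_loc_in_cache: int,
    tp_world_size: int,
) -> torch.Tensor:
    """Each TP rank writes the same cached rows scaled by 1/tp_world_size,
    so the later all_reduce(SUM) across ranks reproduces each row exactly once."""
    offsets = token_ids - img_start_token_id
    rows = embed_cache[img_start_loc_in_cache + offsets, 0, :]
    return rows / tp_world_size
```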

lightllm/models/qwen3_vl/layer_infer/pre_layer_infer.py

Lines changed: 0 additions & 2 deletions
@@ -34,8 +34,6 @@ def context_forward(self, input_ids, infer_state: Qwen3VLInferStateInfo, layer_w
         dtype = layer_weight.wte_weight_.dtype
         hidden_size = layer_weight.wte_weight_.shape[1]
 
-        infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
-
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
                 # skip the same image

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 21 additions & 30 deletions
@@ -32,56 +32,47 @@ def __init__(self, network_config, mode):
         return
 
     def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
-
-        img_weight = []
         img_start_token_ids = []
         img_token_lens = []
-        img_start_loc = 0
-        img_start_locs = []
-
+        img_start_locs_in_cache = []
         device = layer_weight.wte_weight_.device
         dtype = layer_weight.wte_weight_.dtype
         hidden_size = layer_weight.wte_weight_.shape[1]
 
-        infer_state.mark_multimodal_objs_for_prefill(input_ids=input_ids)
-
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
                 # skip the same image
-                if img["token_id"] in img_start_token_ids or img["_prefill_"] is False:
+                if img["token_id"] in img_start_token_ids:
                     continue
-                # pull the img_embeds by uid from shm
-                data = read_shm(get_shm_name_embed(img["uuid"]))
-                img_weight.append(bytes2tensor(data).view(dtype).view(img["token_num"], -1).cuda(non_blocking=True))
                 img_start_token_ids.append(img["token_id"])
                 img_token_lens.append(img["token_num"])
-                img_start_locs.append(img_start_loc)
-                img_start_loc += img["token_num"]
+                img_start_locs_in_cache.append(img["start_index_in_embed_cache"])
         out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)
-        if len(img_weight) > 0:
-            img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype)
-        else:
-            img_weight = torch.empty((0, hidden_size), device=device, dtype=dtype)
-        assert img_weight.shape[1] == hidden_size, (
+
+        from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+        cpu_embed_cache_tensor = g_infer_context.cpu_embed_cache_client.cpu_embed_cache_tensor
+
+        assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
            f"Dimension mismatch: text weight dimension is {hidden_size}, "
-            f"but image weight dimension is {img_weight.shape[1]}"
+            f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}"
        )
         # each tp will fill the img embeds, should divide by world_size
-        img_weight = img_weight / self.tp_world_size_
         img_start_token_ids = torch.Tensor(img_start_token_ids).to(device=device, dtype=torch.long)
         img_token_lens = torch.Tensor(img_token_lens).to(device=device, dtype=torch.long)
-        img_start_locs = torch.Tensor(img_start_locs).to(device=device, dtype=torch.long)
+        img_start_locs_in_cache = torch.Tensor(img_start_locs_in_cache).to(device=device, dtype=torch.long)
 
         multimodal_emb(
-            out,
-            input_ids,
-            layer_weight.wte_weight_,
-            img_weight,
-            img_token_lens,
-            img_start_token_ids,
-            img_start_locs,
-            self.vob_start_id_,
-            self.vob_end_id_,
+            out=out,
+            prompt_ids=input_ids,
+            text_weight_embs=layer_weight.wte_weight_,
+            embed_cache=cpu_embed_cache_tensor,
+            img_token_lens=img_token_lens,
+            img_start_token_ids=img_start_token_ids,
+            img_start_locs_in_cache=img_start_locs_in_cache,
+            tp_text_start_token_id=self.vob_start_id_,
+            tp_text_end_token_id=self.vob_end_id_,
+            tp_world_size=self.tp_world_size_,
         )
         if self.tp_world_size_ > 1:
             all_reduce(out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False)
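Both rewritten pre-layers now run essentially the same collection loop: deduplicate objects by `token_id` and record where each one lives in the shared embed cache, with no shm read of the embeddings themselves. A condensed sketch of that shared pattern (gemma3 iterates only p["images"] while qwen_vl also includes p["audios"]; the helper name is ours, not part of the repo):

```python
import torch

def collect_image_index_tensors(multimodal_params, device: str):
    """Build the three long tensors multimodal_emb expects, one entry per unique
    multimodal object, addressed by its slot in the shared embed cache."""
    start_ids, token_lens, cache_locs = [], [], []
    for p in multimodal_params:
        for obj in p["images"] + p["audios"]:
            if obj["token_id"] in start_ids:  # skip the same image/audio seen again
                continue
            start_ids.append(obj["token_id"])
            token_lens.append(obj["token_num"])
            cache_locs.append(obj["start_index_in_embed_cache"])
    to_long = lambda xs: torch.tensor(xs, dtype=torch.long, device=device)
    return to_long(start_ids), to_long(token_lens), to_long(cache_locs)
```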
