Commit c42e528

Commit message: fix

1 parent 098ae9d commit c42e528

File tree

1 file changed (+41, -33 lines)

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 41 additions & 33 deletions
@@ -33,13 +33,17 @@ def padded_prepare_prefill_inputs(
     b_seq_len = []
     batch_multimodal_params = []
     b_ready_cache_len = []
+    b_mtp_index = []
+    b_prefill_has_output = []
+
     for req in req_objs:

         run_reqs.append(req)
         batch_multimodal_params.append(req.multimodal_params)
         b_req_idx.append(req.req_idx)

         input_token_ids = req.get_chuncked_input_token_ids()
+        b_prefill_has_output.append(False if len(input_token_ids) < req.get_cur_total_len() else True)
         seq_len = len(input_token_ids)
         input_token_len = seq_len - req.cur_kv_len
         input_id = input_token_ids[req.cur_kv_len :]
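The b_prefill_has_output flag added above marks whether a chunked-prefill pass reaches the end of a request's prompt and therefore yields a real output token. Below is a minimal standalone sketch of that decision with a hypothetical helper name (not part of lightllm), assuming get_chuncked_input_token_ids() returns the prompt prefix covered so far and get_cur_total_len() is the full request length:

# Hypothetical illustration only; the real logic is the inline expression in the diff above.
def chunk_has_output(chunked_len: int, total_len: int) -> bool:
    # Only the chunk that reaches the end of the prompt produces a sampled output token.
    return not (chunked_len < total_len)


assert chunk_has_output(chunked_len=512, total_len=512) is True   # final chunk
assert chunk_has_output(chunked_len=256, total_len=512) is False  # intermediate chunk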
@@ -49,27 +53,32 @@ def padded_prepare_prefill_inputs(
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, input_token_len)
         b_ready_cache_len.append(req.cur_kv_len)
+        b_mtp_index.append(0)

     # padding fake req for prefill
     for _ in range(padded_req_num):
         input_ids.append([1])
         b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID)
         b_seq_len.append(1)
+        b_mtp_index.append(0)
+        b_prefill_has_output.append(False)
         b_ready_cache_len.append(0)
         total_token_num += 1
         max_len_in_batch = max(max_len_in_batch, 1)

     input_ids = np.concatenate(input_ids, dtype=np.int64)
-    input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda")
-    b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda")
-    b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda")
-    b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda")
+
+    input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cpu")
+    b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
+    b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
+    b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
+    b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu")

     # dynamic prompt cache 准备 token
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num)
     g_infer_state_lock.release()

     if padded_req_num > 0:
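In the hunk above, the batch metadata tensors are now created on device="cpu" instead of directly on "cuda", and the .cuda() call on the allocated mem_indexes is dropped, so the caller can decide when (and whether) to move them to the device. A rough sketch of that host-first staging pattern, using illustrative tensor names and values rather than the lightllm API:

import torch

# Build small metadata tensors on the host first (cheap, no device sync involved),
# then transfer them in one place. Names and values here are illustrative only.
b_seq_len = torch.tensor([7, 12, 1], dtype=torch.int32, device="cpu")
b_req_idx = torch.tensor([3, 5, -1], dtype=torch.int32, device="cpu")

if torch.cuda.is_available():
    # A later step can push everything to the GPU together, e.g. right before the forward pass.
    b_seq_len_gpu = b_seq_len.cuda(non_blocking=True)
    b_req_idx_gpu = b_req_idx.cuda(non_blocking=True)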
@@ -85,11 +94,13 @@ def padded_prepare_prefill_inputs(
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
         input_ids=input_ids,
-        mem_indexes=mem_indexes,
+        mem_indexes_cpu=mem_indexes,
         b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         b_ready_cache_len=b_ready_cache_len,
         is_prefill=True,
+        b_prefill_has_output_cpu=b_prefill_has_output,
     )
     if is_multimodal:
         model_input.multimodal_params = batch_multimodal_params
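For reference, the ModelInput call above now carries three extra fields: mem_indexes_cpu (host-side KV slot indices), b_mtp_index (per-row draft-token position), and b_prefill_has_output_cpu (per-request output flag). The real dataclass is defined elsewhere in lightllm; the sketch below is only a guess at its shape based on the keyword arguments visible in this diff:

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class ModelInputSketch:  # hypothetical stand-in for lightllm's ModelInput
    batch_size: int
    total_token_num: int
    max_len_in_batch: int
    input_ids: Optional[torch.Tensor]          # None for the padded decode path below
    mem_indexes_cpu: torch.Tensor              # KV-cache slot indices, kept on the CPU
    b_req_idx: torch.Tensor
    b_mtp_index: torch.Tensor                  # 0 for the base token, 1..n for draft tokens
    b_seq_len: torch.Tensor
    is_prefill: bool = False
    b_ready_cache_len: Optional[torch.Tensor] = None
    b_prefill_has_output_cpu: Optional[List[bool]] = None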
@@ -98,64 +109,62 @@ def padded_prepare_prefill_inputs(


 def padded_prepare_decode_inputs(
-    req_objs: List[InferReq], dest_batch_size: Optional[int] = None, is_multimodal=False
+    req_objs: List[InferReq], dest_batch_size: Optional[int] = None
 ) -> Tuple[ModelInput, List[InferReq], int]:
+    mtp_step_num = get_env_start_args().mtp_step
     run_reqs = []
     total_token_num = 0
     max_len_in_batch = 0
-    input_ids = []
     b_req_idx = []
+    b_mtp_index = []
     b_seq_len = []
-
     for req in req_objs:
         run_reqs.append(req)
         b_req_idx.append(req.req_idx)
-        input_token_ids = req.get_input_token_ids()
-        input_id = input_token_ids[-1]
-        seq_len = len(input_token_ids)
+        seq_len = req.get_cur_total_len()
         assert req.cur_kv_len == seq_len - 1
         b_seq_len.append(seq_len)
-        input_ids.append(input_id)
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)
+        b_mtp_index.append(0)
         # process the draft tokens.
-        for step in range(len(req.mtp_gen_token_ids)):
+        for step in range(req.mtp_step):
             run_reqs.append(req)
             b_req_idx.append(req.req_idx)
             seq_len += 1
             b_seq_len.append(seq_len)
-            input_ids.append(req.mtp_gen_token_ids[step])
             total_token_num += seq_len
             max_len_in_batch = max(max_len_in_batch, seq_len)
+            b_mtp_index.append(step + 1)

     if dest_batch_size is None:
         if len(run_reqs) == 0:
             dest_batch_size = 1
         else:
-            dest_batch_size = len(run_reqs)
+            dest_batch_size = len(run_reqs) * (1 + mtp_step_num)
     else:
-        assert len(run_reqs) <= dest_batch_size
+        assert len(run_reqs) * (1 + mtp_step_num) <= dest_batch_size

-    padded_req_num = dest_batch_size - len(run_reqs)
+    padded_req_num = dest_batch_size - len(run_reqs) * (1 + mtp_step_num)

     # padding fake req for decode
     for _ in range(padded_req_num):
-        input_ids.append(1)
         seq_len = 2
         b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID)
         b_seq_len.append(seq_len)
+        b_mtp_index.append(0)
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)

-    input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda")
-    b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda")
-    b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda")
+    b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
+    b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
+    b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")

     # dynamic prompt cache 准备 token
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num)
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_req_num)
     g_infer_state_lock.release()

     if padded_req_num > 0:
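In the decode path above, every request now contributes one row for its verified token plus one row per draft step (tracked by b_mtp_index), and the padding target is sized as len(run_reqs) * (1 + mtp_step_num) using the mtp_step value from the start args, rather than the raw request count. A small worked example of that arithmetic, using a hypothetical helper name that mirrors the logic in the diff:

from typing import Optional


def compute_padded_req_num(num_run_reqs: int, mtp_step_num: int, dest_batch_size: Optional[int]) -> int:
    """Hypothetical helper mirroring the padding arithmetic in the diff above."""
    rows = num_run_reqs * (1 + mtp_step_num)   # each counted request occupies 1 + mtp_step_num rows
    if dest_batch_size is None:
        dest_batch_size = rows if num_run_reqs > 0 else 1
    assert rows <= dest_batch_size
    return dest_batch_size - rows              # number of fake requests to pad in


assert compute_padded_req_num(num_run_reqs=3, mtp_step_num=1, dest_batch_size=8) == 2
assert compute_padded_req_num(num_run_reqs=0, mtp_step_num=1, dest_batch_size=None) == 1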
@@ -170,36 +179,35 @@ def padded_prepare_decode_inputs(
         batch_size=b_seq_len.shape[0],
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
-        input_ids=input_ids,
-        mem_indexes=mem_indexes,
+        input_ids=None,
+        mem_indexes_cpu=mem_indexes,
         b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         is_prefill=False,
     )
     return model_input, run_reqs, padded_req_num


-def padded_overlap_prepare_decode_inputs(req_objs: List[InferReq], is_multimodal=False):
+def padded_overlap_prepare_decode_inputs(req_objs: List[InferReq]):
     split_req_bound = triton.cdiv(len(req_objs), 2)
     req_objs_0 = req_objs[0:split_req_bound]
     req_objs_1 = req_objs[split_req_bound:]

     enable_mtp = get_env_start_args().mtp_mode is not None
     if enable_mtp:
         micro_batch_size = max(
-            sum([len(req.mtp_gen_token_ids) + 1 for req in req_objs_0]),
-            sum([len(req.mtp_gen_token_ids) + 1 for req in req_objs_1]),
+            sum([req.mtp_step + 1 for req in req_objs_0]),
+            sum([req.mtp_step + 1 for req in req_objs_1]),
         )
     else:
         micro_batch_size = triton.cdiv(len(req_objs), 2)

     micro_batch_size = max(1, micro_batch_size)

-    micro_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(
-        req_objs_0, dest_batch_size=micro_batch_size, is_multimodal=is_multimodal
-    )
+    micro_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(req_objs_0, dest_batch_size=micro_batch_size)
     micro_input1, run_reqs1, padded_req_num1 = padded_prepare_decode_inputs(
-        req_objs_1, dest_batch_size=micro_batch_size, is_multimodal=is_multimodal
+        req_objs_1, dest_batch_size=micro_batch_size
     )
     return micro_input, run_reqs, padded_req_num, micro_input1, run_reqs1, padded_req_num1
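padded_overlap_prepare_decode_inputs, shown in the last hunk, splits the requests into two halves and pads both to the same micro_batch_size so the two overlapped micro batches stay shape-compatible; with MTP enabled, that size is the larger half's row count (one row per request plus one per draft step). A hedged sketch of that sizing rule, with an illustrative helper name not present in lightllm:

from typing import List


def overlap_micro_batch_size(mtp_steps_half0: List[int], mtp_steps_half1: List[int]) -> int:
    """Illustrative only: per-half mtp_step values in, shared micro batch size out."""
    size = max(
        sum(step + 1 for step in mtp_steps_half0),
        sum(step + 1 for step in mtp_steps_half1),
    )
    return max(1, size)  # never schedule an empty micro batch


# Two requests with one draft step each vs. one such request: both halves get padded to 4 rows.
assert overlap_micro_batch_size([1, 1], [1]) == 4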

0 commit comments