
Commit 089e617

fix
1 parent ddf151d commit 089e617

3 files changed: 14 additions, 4 deletions


lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 2 additions & 1 deletion
@@ -253,7 +253,8 @@ def decode_mtp(
 
         # Handle the memory indexes that need to be freed
         need_free_mem_indexes = model_input.mem_indexes_cpu[verify_info["accepted_index_cpu"] == 0]
-        need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0)
+        if additional_mem_indexes_cpu is not None:
+            need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0)
 
         self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=verify_info["mtp_accept_len_cpu"])
         select_mask = torch.tensor(verify_info["accepted_index_cpu"], dtype=torch.bool, device="cpu")
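
The change above guards an optional tensor before concatenation: torch.cat raises a TypeError if any element of its input list is None. A minimal standalone sketch of the pattern follows; the tensor values, and the assumption that additional_mem_indexes_cpu is None whenever no extra draft-token slots were allocated, are illustrative rather than taken from the repo.

import torch

# Minimal sketch of the guarded merge of free-list indexes.
# Assumption (not from the repo): additional_mem_indexes_cpu is None
# when there are no extra draft-token slots to release this step.
need_free_mem_indexes = torch.tensor([3, 7, 9], dtype=torch.int64)
additional_mem_indexes_cpu = None

# Before the fix, torch.cat([need_free_mem_indexes, None], dim=0)
# would raise a TypeError ("expected Tensor as element 1 ...").
if additional_mem_indexes_cpu is not None:
    need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0)

print(need_free_mem_indexes)  # tensor([3, 7, 9]) when nothing extra to free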

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,7 @@ def __init__(self) -> None:
         # In mtp mode, switch the bound prefill and decode functions
         if get_env_start_args().mtp_mode:
             self.is_mtp_eagle = get_env_start_args().mtp_mode == "deepseekv3_eagle"
-            self.prefill_mtp_step = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step
+            self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step
             if self.enable_prefill_microbatch_overlap:
                 self.prefill = self.prefill_overlap_mtp
             else:
@@ -360,7 +360,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
         self._draft_prefill_forward(
             model_input=model_input,
             model_output=model_output,
-            mtp_step=self.prefill_mtp_step,
+            mtp_step=self.num_mtp_models,
             next_token_ids=draft_next_token_ids_gpu,
         )
         sync_event = torch.cuda.Event()
@@ -596,7 +596,7 @@ def prefill_overlap_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[I
 
         draft_model_output0, draft_model_output1 = model_output0, model_output1
 
-        for draft_model_idx in range(self.prefill_mtp_step):
+        for draft_model_idx in range(self.num_mtp_models):
 
             draft_model_input0 = prepare_mtp_prefill_inputs(
                 model_input=draft_model_input0,
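
The rename from prefill_mtp_step to num_mtp_models reflects what the value actually counts: in deepseekv3_eagle mode a single draft model is reused, so the draft prefill loop runs once; otherwise there is one draft model per mtp step. A standalone sketch of how the attribute resolves; the mode strings come from the diff, while the mtp_step value of 3 is an illustrative assumption.

# Sketch of the renamed attribute, outside the backend class.
mtp_mode = "deepseekv3_eagle"  # mode string from the diff
mtp_step = 3                   # assumed CLI value, for illustration

is_mtp_eagle = mtp_mode == "deepseekv3_eagle"
num_mtp_models = 1 if is_mtp_eagle else mtp_step

# Eagle mode reuses one draft model, so this loop runs exactly once.
for draft_model_idx in range(num_mtp_models):
    print(f"draft prefill pass for draft model {draft_model_idx}")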

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 9 additions & 0 deletions
@@ -133,6 +133,8 @@ def padded_prepare_decode_inputs(
     b_req_idx = []
     b_mtp_index = []
     b_seq_len = []
+    max_q_seq_len = 0
+    max_kv_seq_len = 0
     for req in req_objs:
         run_reqs.append(req)
         b_req_idx.append(req.req_idx)
@@ -141,6 +143,8 @@ def padded_prepare_decode_inputs(
         b_seq_len.append(seq_len)
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)
+        max_q_seq_len = max(max_q_seq_len, req.mtp_step + 1)
+        max_kv_seq_len = max(max_kv_seq_len, seq_len)
         b_mtp_index.append(0)
         # process the draft tokens.
         for step in range(req.mtp_step):
@@ -150,6 +154,7 @@ def padded_prepare_decode_inputs(
             b_seq_len.append(seq_len)
             total_token_num += seq_len
             max_len_in_batch = max(max_len_in_batch, seq_len)
+            max_kv_seq_len = max(max_kv_seq_len, seq_len)
             b_mtp_index.append(step + 1)
 
     if dest_batch_size is None:
@@ -170,6 +175,8 @@ def padded_prepare_decode_inputs(
         b_mtp_index.append(0)
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)
+        max_q_seq_len = max(max_q_seq_len, 1)
+        max_kv_seq_len = max(max_kv_seq_len, seq_len)
 
     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
@@ -194,6 +201,8 @@ def padded_prepare_decode_inputs(
         batch_size=b_seq_len.shape[0],
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
+        max_q_seq_len=max_q_seq_len,
+        max_kv_seq_len=max_kv_seq_len,
         input_ids=None,
         mem_indexes_cpu=mem_indexes,
         b_req_idx=b_req_idx,
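
The new max_q_seq_len and max_kv_seq_len fields track per-batch maxima alongside max_len_in_batch: a decode request's query covers its next token plus its mtp_step draft tokens, while its kv length is the running sequence length, and padding rows contribute a query length of 1. A self-contained sketch of the accumulation; the Req class is a hypothetical stand-in for the real request object, and the assumption that each draft token extends seq_len by one is inferred from the loop structure in the diff, not confirmed by the repo.

# Hypothetical Req stand-in; the real InferReq carries more state.
class Req:
    def __init__(self, mtp_step: int, seq_len: int):
        self.mtp_step = mtp_step
        self.seq_len = seq_len

reqs = [Req(mtp_step=1, seq_len=128), Req(mtp_step=2, seq_len=64)]

max_q_seq_len = 0
max_kv_seq_len = 0
for req in reqs:
    seq_len = req.seq_len
    max_q_seq_len = max(max_q_seq_len, req.mtp_step + 1)  # next token + drafts
    max_kv_seq_len = max(max_kv_seq_len, seq_len)
    for step in range(req.mtp_step):  # assumed: each draft token grows the kv cache by one
        seq_len += 1
        max_kv_seq_len = max(max_kv_seq_len, seq_len)

print(max_q_seq_len, max_kv_seq_len)  # 3 129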
