Commit 149c72d

fix overlap
Parent: c02732f

4 files changed: +12 −13 lines

lightllm/models/deepseek2/flashattention_infer_struct.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -53,5 +53,4 @@ def init_some_extra_state(self, model, model_input: ModelInput):
         self.page_table[:, :max_seq_len_k].copy_(
             model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
         )
-        self.page_table[:, max_seq_len_k:].fill_(0)
         return
```
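For orientation, here is a minimal sketch of the page-table pattern this hunk touches, with made-up shapes (the real tensors come from lightllm's request manager): the valid prefix of each row is copied from the request-to-token index map, and after this commit the tail past `max_seq_len_k` is left untouched instead of being zero-filled on every step.

```python
import torch

# Hypothetical sizes; lightllm derives these from the live batch.
batch_size, max_pages, max_seq_len_k = 4, 16, 10

# Maps (request slot, position) -> KV-cache token index.
req_to_token_indexs = torch.arange(8 * max_pages).reshape(8, max_pages)
b_req_idx = torch.tensor([0, 2, 5, 7])  # request slots in this batch

page_table = torch.empty((batch_size, max_pages), dtype=torch.int64)
# Kept by the commit: copy the valid prefix for every request.
page_table[:, :max_seq_len_k].copy_(
    req_to_token_indexs[b_req_idx, :max_seq_len_k]
)
# Removed by the commit: zero-filling the tail on every step.
#   page_table[:, max_seq_len_k:].fill_(0)
# Entries past max_seq_len_k are now left as-is, which is only safe if
# the attention kernel never reads beyond each sequence's length.
```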

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -265,7 +265,7 @@ def decode_mtp(
         decode_reqs: List[InferReq],
     ):
         model_input, run_reqs = prepare_decode_inputs(decode_reqs)
-        b_mtp_index_cpu = model_input.b_mtp_index
+        b_mtp_index_cpu = model_input.b_mtp_index_cpu
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
             model_output = self.model.forward(model_input)
             all_next_token_ids = []
```
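The rename does real work: `b_mtp_index_cpu` is now read from an explicit host-side field rather than a device tensor that happened to share the name. A minimal sketch of why that matters around an overlap stream, using hypothetical tensors rather than lightllm's actual `ModelInput`:

```python
import torch

# Hypothetical pair of fields standing in for ModelInput's b_mtp_index*.
b_mtp_index_cpu = torch.tensor([0, 1, 0, 2])  # host copy
b_mtp_index = (
    b_mtp_index_cpu.cuda() if torch.cuda.is_available() else b_mtp_index_cpu
)  # device copy, consumed by the kernels

# Reading the host copy costs no device->host transfer, so the Python
# side can snapshot it without synchronizing against GPU work that is
# about to be queued on a separate (overlap) stream.
snapshot = b_mtp_index_cpu.tolist()

if torch.cuda.is_available():
    overlap_stream = torch.cuda.Stream()
    with torch.cuda.stream(overlap_stream):
        _ = b_mtp_index + 1  # device work proceeds independently
```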

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -449,7 +449,7 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]

     def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
         model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(decode_reqs)
-        b_mtp_index_cpu = model_input.b_mtp_index
+        b_mtp_index_cpu = model_input.b_mtp_index_cpu
         req_num = len(run_reqs)

         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
@@ -680,8 +680,8 @@ def decode_overlap_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[Inf
         ) = padded_overlap_prepare_decode_inputs(decode_reqs)
         req_num0, req_num1 = len(run_reqs0), len(run_reqs1)
         all_next_token_ids = []
-        b_mtp_index_cpu0 = micro_input0.b_mtp_index
-        b_mtp_index_cpu1 = micro_input1.b_mtp_index
+        b_mtp_index_cpu0 = micro_input0.b_mtp_index_cpu
+        b_mtp_index_cpu1 = micro_input1.b_mtp_index_cpu
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):

             model_output0, model_output1 = self.model.microbatch_overlap_decode(micro_input0, micro_input1)
```

lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -93,12 +93,12 @@ def padded_prepare_prefill_inputs(
         batch_size=b_seq_len.shape[0],
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
-        input_ids=input_ids,
+        input_ids_cpu=input_ids,
         mem_indexes_cpu=mem_indexes,
-        b_req_idx=b_req_idx,
-        b_mtp_index=b_mtp_index,
-        b_seq_len=b_seq_len,
-        b_ready_cache_len=b_ready_cache_len,
+        b_req_idx_cpu=b_req_idx,
+        b_mtp_index_cpu=b_mtp_index,
+        b_seq_len_cpu=b_seq_len,
+        b_ready_cache_len_cpu=b_ready_cache_len,
         is_prefill=True,
         b_prefill_has_output_cpu=b_prefill_has_output,
     )
@@ -180,9 +180,9 @@ def padded_prepare_decode_inputs(
         max_len_in_batch=max_len_in_batch,
         input_ids=None,
         mem_indexes_cpu=mem_indexes,
-        b_req_idx=b_req_idx,
-        b_mtp_index=b_mtp_index,
-        b_seq_len=b_seq_len,
+        b_req_idx_cpu=b_req_idx,
+        b_mtp_index_cpu=b_mtp_index,
+        b_seq_len_cpu=b_seq_len,
         is_prefill=False,
     )
     return model_input, run_reqs, padded_req_num
```
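Taken together, the kwarg renames suggest that `ModelInput`'s host-side tensor fields now carry a uniform `_cpu` suffix, while the decode call site still passes `input_ids=None`, implying a device-side `input_ids` field also exists. A hypothetical sketch of the field layout, inferred from the call sites above rather than taken from lightllm's actual class:

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ModelInput:  # inferred layout; not lightllm's real definition
    batch_size: int
    total_token_num: int
    max_len_in_batch: int
    is_prefill: bool
    # Device-side field, still accepted at the decode call site above.
    input_ids: Optional[torch.Tensor] = None
    # Host-side tensors, now named with an explicit _cpu suffix.
    input_ids_cpu: Optional[torch.Tensor] = None
    mem_indexes_cpu: Optional[torch.Tensor] = None
    b_req_idx_cpu: Optional[torch.Tensor] = None
    b_mtp_index_cpu: Optional[torch.Tensor] = None
    b_seq_len_cpu: Optional[torch.Tensor] = None
    b_ready_cache_len_cpu: Optional[torch.Tensor] = None
    b_prefill_has_output_cpu: Optional[torch.Tensor] = None
```

On this reading, "fix overlap" means binding the host tensors to the right fields at construction time, so that later reads such as `model_input.b_mtp_index_cpu` stay off the device while the overlap stream runs.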
