Commit 8599b45
Message: fix
Parent: dff8618

2 files changed (+24, -5 lines)

lightllm/common/basemodel/basemodel.py

Lines changed: 19 additions & 0 deletions

@@ -463,6 +463,9 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo):
 
     @torch.no_grad()
     def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
+        model_input0.to_cuda()
+        model_input1.to_cuda()
+
         assert model_input0.mem_indexes.is_cuda
         assert model_input1.mem_indexes.is_cuda
         input_ids0, input_ids1 = model_input0.input_ids, model_input1.input_ids
@@ -500,6 +503,22 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: ModelInput):
 
     @torch.no_grad()
     def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: ModelInput):
+        model_input0.to_cuda()
+        model_input1.to_cuda()
+
+        if model_input0.input_ids is None:
+            model_input0.input_ids = gather_token(
+                self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+                model_input0.b_req_idx,
+                model_input0.b_mtp_index,
+            )
+        if model_input1.input_ids is None:
+            model_input1.input_ids = gather_token(
+                self.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+                model_input1.b_req_idx,
+                model_input1.b_mtp_index,
+            )
+
         assert model_input0.batch_size == model_input1.batch_size
         assert model_input0.mem_indexes.is_cuda
         assert model_input1.mem_indexes.is_cuda
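Note on the basemodel.py change: both overlap entry points now move their ModelInput tensors onto the GPU themselves (to_cuda()) before the is_cuda asserts, and the decode path fills in input_ids lazily from the per-request next-token table when the caller leaves them as None. Below is a minimal standalone sketch of that lazy-gather pattern; the table-indexing gather_token is an assumed stand-in for lightllm's real helper, and all shapes are invented for illustration:

import torch

def gather_token(req_to_next_token_ids: torch.Tensor,
                 b_req_idx: torch.Tensor,
                 b_mtp_index: torch.Tensor) -> torch.Tensor:
    # Assumed behavior: look up each request's pending next-token id in a
    # [max_reqs, mtp_slots] table using (request index, MTP slot) pairs.
    return req_to_next_token_ids[b_req_idx, b_mtp_index]

# Hypothetical decode micro-batch that arrives without input_ids.
req_to_next_token_ids = torch.randint(0, 32000, (8, 2))
b_req_idx = torch.tensor([0, 3, 5])
b_mtp_index = torch.tensor([0, 0, 1])

input_ids = None
if input_ids is None:
    input_ids = gather_token(req_to_next_token_ids, b_req_idx, b_mtp_index)
print(input_ids.shape)  # torch.Size([3])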

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 5 additions & 5 deletions

@@ -235,8 +235,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]):
         b_has_out_cpu = (
             micro_input0.b_prefill_has_output_cpu[0:req_num0] + micro_input1.b_prefill_has_output_cpu[0:req_num1]
         )
-        b_mtp_index = torch.cat(micro_input0.b_mtp_index[0:req_num0], micro_input1.b_mtp_index[0:req_num1])
-        b_req_idx = torch.cat(micro_input0.b_req_idx[0:req_num0], micro_input1.b_req_idx[0:req_num1])
+        b_mtp_index = torch.cat((micro_input0.b_mtp_index[0:req_num0], micro_input1.b_mtp_index[0:req_num1]), dim=0)
+        b_req_idx = torch.cat((micro_input0.b_req_idx[0:req_num0], micro_input1.b_req_idx[0:req_num1]), dim=0)
 
         if (req_num0 + req_num1) > 0:
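The substantive fix in this file: torch.cat expects a sequence of tensors as its first argument, so the old calls, which passed the two slices as separate positional arguments, put the second tensor into the dim slot and raise a TypeError at runtime. A quick runnable demonstration:

import torch

a = torch.tensor([1, 2])
b = torch.tensor([3, 4])

try:
    torch.cat(a, b)  # old form: `b` lands in the `dim` parameter -> TypeError
except TypeError as err:
    print("old call fails:", err)

print(torch.cat((a, b), dim=0))  # fixed form: tensor([1, 2, 3, 4])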

@@ -291,7 +291,7 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
             micro_input1,
             run_reqs1,
             padded_req_num1,
-        ) = padded_overlap_prepare_decode_inputs(decode_reqs, is_multimodal=self.is_multimodal)
+        ) = padded_overlap_prepare_decode_inputs(req_objs=decode_reqs)
         micro_input0: ModelInput = micro_input0
         micro_input1: ModelInput = micro_input1
 
@@ -305,8 +305,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
 
         logits[0:req_num0, :].copy_(logits0[0:req_num0, :], non_blocking=True)
         logits[req_num0 : (req_num0 + req_num1), :].copy_(logits1[0:req_num1, :], non_blocking=True)
-        b_mtp_index = torch.cat(micro_input0.b_mtp_index[0:req_num0], micro_input1.b_mtp_index[0:req_num1])
-        b_req_idx = torch.cat(micro_input0.b_req_idx[0:req_num0], micro_input1.b_req_idx[0:req_num1])
+        b_mtp_index = torch.cat((micro_input0.b_mtp_index[0:req_num0], micro_input1.b_mtp_index[0:req_num1]), dim=0)
+        b_req_idx = torch.cat((micro_input0.b_req_idx[0:req_num0], micro_input1.b_req_idx[0:req_num1]), dim=0)
 
         run_reqs = run_reqs0 + run_reqs1
         if (req_num0 + req_num1) > 0:
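For context, the unchanged lines around this hunk stitch the two micro-batch outputs back into one preallocated logits buffer with slice copies before the merged metadata is built. A small sketch of that merge pattern, with all shapes invented for illustration (non_blocking=True only has an effect for CUDA/pinned-memory copies, so the sketch also runs on CPU):

import torch

req_num0, req_num1, vocab = 3, 2, 10
logits0 = torch.randn(req_num0, vocab)  # micro-batch 0 output
logits1 = torch.randn(req_num1, vocab)  # micro-batch 1 output

# Preallocated buffer covering both micro-batches.
logits = torch.empty(req_num0 + req_num1, vocab)
logits[0:req_num0, :].copy_(logits0, non_blocking=True)
logits[req_num0:(req_num0 + req_num1), :].copy_(logits1, non_blocking=True)

# Per-request metadata is merged the same way, now via a proper tuple.
b_req_idx = torch.cat((torch.arange(req_num0), torch.arange(req_num1)), dim=0)
print(logits.shape, b_req_idx.shape)  # torch.Size([5, 10]) torch.Size([5])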
