Commit 8a8059d

update
1 parent c0b2ba1 commit 8a8059d

File tree: 1 file changed, +27 -119 lines
  • lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 27 additions & 119 deletions
@@ -45,7 +45,7 @@ def __init__(self) -> None:
             if self.enable_decode_microbatch_overlap:
                 self.decode = self.decode_overlap_mtp
             else:
-                self.decode = self.decode_mtp_eagle
+                self.decode = self.decode_mtp
         else:
             if self.enable_prefill_microbatch_overlap:
                 self.prefill = self.prefill_overlap
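The hunk above changes only which bound method __init__ assigns to self.decode: the eagle MTP mode no longer binds a dedicated decode_mtp_eagle entry point and instead routes through the shared decode_mtp. A minimal, standalone sketch of that bound-method dispatch pattern; the class and flags below are illustrative stand-ins, not lightllm's actual backend:

# Illustrative sketch only: the class and flag handling are stand-ins for the
# real lightllm backend. It shows the "pick the decode implementation once, in
# __init__" pattern that this hunk adjusts.
class BackendSketch:
    def __init__(self, enable_decode_microbatch_overlap: bool, is_mtp_eagle: bool):
        self.is_mtp_eagle = is_mtp_eagle
        if enable_decode_microbatch_overlap:
            self.decode = self.decode_overlap_mtp
        else:
            # Previously the eagle mode bound a separate decode_mtp_eagle here;
            # after this commit both MTP variants share decode_mtp and branch
            # internally on self.is_mtp_eagle.
            self.decode = self.decode_mtp

    def decode_overlap_mtp(self, reqs):
        return f"overlap decode: {len(reqs)} reqs"

    def decode_mtp(self, reqs):
        draft_kind = "eagle" if self.is_mtp_eagle else "vanilla"
        return f"{draft_kind} mtp decode: {len(reqs)} reqs"


backend = BackendSketch(enable_decode_microbatch_overlap=False, is_mtp_eagle=True)
print(backend.decode(["req0", "req1"]))  # eagle mtp decode: 2 reqs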
@@ -396,7 +396,12 @@ def prefill_mtp(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]
     def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
         model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(decode_reqs)
         b_mtp_index_cpu = model_input.b_mtp_index
+        eagle_mem_indexes_cpu = None
         req_num = len(run_reqs)
+        if self.is_mtp_eagle:
+            draft_model_input, eagle_mem_indexes_cpu = padded_prepare_eagle_decode_inputs(
+                decode_reqs, padded_req_num=padded_req_num, mtp_step=self.mtp_step
+            )
 
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
             model_output = self.model.forward(model_input)
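With the entry points unified, decode_mtp prepares the extra eagle draft inputs lazily: eagle_mem_indexes_cpu starts as None and is only populated when self.is_mtp_eagle is set. A small self-contained sketch of that guard; prepare_eagle_inputs_stub is a hypothetical stand-in whose call shape mirrors the padded_prepare_eagle_decode_inputs call shown above, not the real helper:

# Sketch of the "prepare extra draft inputs only on the eagle path" guard.
# prepare_eagle_inputs_stub is hypothetical; only its call shape
# (decode_reqs, padded_req_num, mtp_step) -> (draft_input, mem_indexes_cpu)
# follows the diff above.
from typing import List, Optional, Tuple

import torch


def prepare_eagle_inputs_stub(
    decode_reqs: List[object], padded_req_num: int, mtp_step: int
) -> Tuple[dict, torch.Tensor]:
    # Assume mtp_step scratch KV slots per request (real + padded); the real
    # allocation policy lives inside padded_prepare_eagle_decode_inputs.
    total_slots = (len(decode_reqs) + padded_req_num) * mtp_step
    mem_indexes_cpu = torch.arange(total_slots, dtype=torch.int64)
    return {"mem_indexes_cpu": mem_indexes_cpu}, mem_indexes_cpu


def build_decode_inputs(decode_reqs, padded_req_num, mtp_step, is_mtp_eagle):
    draft_model_input: Optional[dict] = None
    eagle_mem_indexes_cpu: Optional[torch.Tensor] = None  # stays None on the vanilla path
    if is_mtp_eagle:
        draft_model_input, eagle_mem_indexes_cpu = prepare_eagle_inputs_stub(
            decode_reqs, padded_req_num=padded_req_num, mtp_step=mtp_step
        )
    return draft_model_input, eagle_mem_indexes_cpu


draft_input, eagle_idx = build_decode_inputs(
    ["r0", "r1"], padded_req_num=2, mtp_step=1, is_mtp_eagle=True
)
print(eagle_idx)  # tensor([0, 1, 2, 3])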
@@ -436,16 +441,25 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
436441

437442
verify_event = torch.cuda.Event()
438443
verify_event.record()
439-
440-
self._draft_decode_vanilla(
441-
model_input=model_input,
442-
model_output=model_output,
443-
draft_next_token_ids_gpu=draft_next_token_ids_gpu,
444-
b_req_mtp_start_loc=b_req_mtp_start_loc[:real_req_num],
445-
mtp_accept_len=mtp_accept_len[:real_req_num],
446-
req_num=req_num,
447-
)
448-
444+
if self.is_mtp_eagle:
445+
self._draft_decode_eagle(
446+
model_input=model_input,
447+
model_output=model_output,
448+
draft_next_token_ids_gpu=draft_next_token_ids_gpu,
449+
mtp_accept_len=mtp_accept_len,
450+
eagle_mem_indexes_cpu=eagle_mem_indexes_cpu,
451+
draft_model_input=draft_model_input,
452+
padded_req_num=padded_req_num,
453+
)
454+
else:
455+
self._draft_decode_vanilla(
456+
model_input=model_input,
457+
model_output=model_output,
458+
draft_next_token_ids_gpu=draft_next_token_ids_gpu,
459+
b_req_mtp_start_loc=b_req_mtp_start_loc[:real_req_num],
460+
mtp_accept_len=mtp_accept_len[:real_req_num],
461+
req_num=req_num,
462+
)
449463
if req_num > 0:
450464
g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
451465
b_req_idx=model_input.b_req_idx[:req_num],
@@ -468,6 +482,8 @@ def decode_mtp(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
             event_pack.notify_forward_and_wait_post_handle()
             sync_event.synchronize()
             need_free_mem_indexes = model_input.mem_indexes_cpu[0:req_num][accepted_index_cpu == 0]
+            if self.is_mtp_eagle:
+                need_free_mem_indexes = torch.cat([need_free_mem_indexes, eagle_mem_indexes_cpu], dim=0)
 
             self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu)
             select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu")
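The cleanup step stays mask-based: KV slots that belonged to rejected draft tokens are collected for freeing, and on the eagle path the scratch slots reserved earlier (eagle_mem_indexes_cpu) are concatenated onto that free list. A self-contained sketch with made-up tensor values:

# Standalone sketch (made-up values) of the memory-release logic in this hunk:
# slots whose draft tokens failed verification are freed, and on the eagle path
# the extra scratch slots reserved for the draft model are freed as well.
from typing import Optional

import torch


def collect_free_indexes(
    mem_indexes_cpu: torch.Tensor,            # one KV slot per candidate token
    accepted_index_cpu: torch.Tensor,         # 1 = accepted by verification, 0 = rejected
    req_num: int,
    eagle_mem_indexes_cpu: Optional[torch.Tensor] = None,  # eagle-only scratch slots
) -> torch.Tensor:
    need_free = mem_indexes_cpu[:req_num][accepted_index_cpu[:req_num] == 0]
    if eagle_mem_indexes_cpu is not None:
        need_free = torch.cat([need_free, eagle_mem_indexes_cpu], dim=0)
    return need_free


# Example: 4 candidate slots, tokens 0 and 2 accepted, plus 2 eagle scratch slots.
free = collect_free_indexes(
    mem_indexes_cpu=torch.tensor([10, 11, 12, 13]),
    accepted_index_cpu=torch.tensor([1, 0, 1, 0]),
    req_num=4,
    eagle_mem_indexes_cpu=torch.tensor([20, 21]),
)
print(free)  # tensor([11, 13, 20, 21])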
@@ -528,122 +544,14 @@ def _draft_decode_vanilla(
         )
         return all_next_token_ids
 
-    def decode_mtp_eagle(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
-        model_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(decode_reqs)
-        draft_model_input, eagle_mem_indexes_cpu = padded_prepare_eagle_decode_inputs(
-            decode_reqs, padded_req_num=padded_req_num, mtp_step=self.mtp_step
-        )
-        b_mtp_index_cpu = model_input.b_mtp_index
-        req_num = len(run_reqs)
-
-        with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-            model_output = self.model.forward(model_input)
-            draft_next_token_ids_gpu = torch.zeros((model_input.batch_size), dtype=torch.int64, device="cuda")
-
-            if req_num > 0:
-                logits = model_output.logits[0:req_num, :]
-                next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
-                next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
-                    next_token_ids, next_token_logprobs
-                )
-                draft_next_token_ids_gpu[0:req_num].copy_(next_token_ids)
-
-            # verify the next_token_ids
-            b_req_mtp_start_loc = [index for index, mtp_index in enumerate(b_mtp_index_cpu) if mtp_index == 0]
-            b_req_mtp_start_loc = g_pin_mem_manager.gen_from_list(
-                key="b_req_mtp_start_loc",
-                data=b_req_mtp_start_loc,
-                dtype=torch.int32,
-            ).cuda(non_blocking=True)
-            # real request count, excluding the mtp-padded part
-            real_req_num = b_req_mtp_start_loc.shape[0] - padded_req_num
-
-            mtp_accept_len, accepted_index = self._verify_mtp_v2(
-                new_next_token_ids=draft_next_token_ids_gpu,
-                b_req_idx=model_input.b_req_idx,
-                b_req_mtp_start_loc=b_req_mtp_start_loc,
-            )
-            accepted_index_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
-                key="accepted_index",
-                gpu_tensor=accepted_index[:req_num],
-            )
-            mtp_accept_len_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
-                key="mtp_accept_len",
-                gpu_tensor=mtp_accept_len[:real_req_num],
-            )
-
-            verify_event = torch.cuda.Event()
-            verify_event.record()
-
-            self._draft_decode_eagle(
-                model_input=model_input,
-                model_output=model_output,
-                draft_next_token_ids_gpu=draft_next_token_ids_gpu,
-                b_req_mtp_start_loc=b_req_mtp_start_loc,
-                mtp_accept_len=mtp_accept_len,
-                eagle_mem_indexes_cpu=eagle_mem_indexes_cpu,
-                draft_model_input=draft_model_input,
-                req_num=req_num,
-                padded_req_num=padded_req_num,
-            )
-
-            if req_num > 0:
-                g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
-                    b_req_idx=model_input.b_req_idx[:req_num],
-                    next_token_ids=next_token_ids,
-                    mask=accepted_index[:req_num] == 1,
-                )
-
-            sync_event = torch.cuda.Event()
-            sync_event.record()
-
-        if req_num > 0:
-            # stage 2
-            accepted_index_cpu = accepted_index_cpu[:req_num]
-            event_pack.notify_post_handle_and_wait_pre_post_handle()
-            verify_event.synchronize()
-            verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu[i] == 1]
-            update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)
-
-            # stage 3
-            event_pack.notify_forward_and_wait_post_handle()
-            sync_event.synchronize()
-            need_free_mem_indexes = torch.cat(
-                [model_input.mem_indexes_cpu[0:req_num][accepted_index_cpu == 0], eagle_mem_indexes_cpu], dim=0
-            )
-
-            self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu)
-            select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu")
-            self._post_handle(
-                run_reqs=verify_ok_reqs,
-                next_token_ids=next_token_ids_cpu[select_mask],
-                next_token_logprobs=next_token_logprobs_cpu[select_mask],
-                run_reqs_update_packs=update_packs,
-                extra_post_req_handle_func=self.extra_post_req_handle_func,
-            )
-            if len(need_free_mem_indexes) > 0:
-                g_infer_state_lock.acquire()
-                g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes)
-                g_infer_state_lock.release()
-
-            # stage 4
-            event_pack.notify_pre_post_handle()
-        else:
-            event_pack.notify_post_handle_and_wait_pre_post_handle()
-            event_pack.notify_forward_and_wait_post_handle()
-            event_pack.notify_pre_post_handle()
-        return
-
     def _draft_decode_eagle(
         self,
         model_input: ModelInput,
         model_output: ModelOutput,
         draft_next_token_ids_gpu: torch.Tensor,
-        b_req_mtp_start_loc: torch.Tensor,
         mtp_accept_len: torch.Tensor,
         eagle_mem_indexes_cpu: torch.Tensor,
         draft_model_input: ModelInput,
-        req_num: int,
         padded_req_num: int,
     ):
         all_next_token_ids = []
0 commit comments

Comments
 (0)