Commit 6997a18

dp mtp eagle support
1 parent edc0ff6 commit 6997a18

File tree

6 files changed: +431 additions, -204 deletions


lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 40 additions & 0 deletions

@@ -31,6 +31,8 @@
 from lightllm.server.router.shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
+from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
+from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token


 class ModeBackend:

@@ -574,6 +576,44 @@ def _gen_argmax_token_ids(self, model_output: ModelOutput):
         draft_next_token_ids_gpu = torch.argmax(probs, dim=-1)
         return draft_next_token_ids_gpu

+    def _sample_and_scatter_token(
+        self,
+        logits: torch.Tensor,
+        b_req_idx: torch.Tensor,
+        b_mtp_index: torch.Tensor,
+        run_reqs: List[InferReq],
+        is_prefill: bool,
+        b_prefill_has_output_cpu: torch.Tensor = None,
+        mask_func: Optional[Callable] = None,
+    ):
+
+        if mask_func is not None:
+            mask_func(run_reqs, logits)
+
+        next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
+        b_has_out = None
+        if is_prefill:
+            b_has_out = g_pin_mem_manager.gen_from_list(
+                key="b_has_out", data=b_prefill_has_output_cpu, dtype=torch.bool
+            ).cuda(non_blocking=True)
+
+        scatter_token(
+            next_token_ids=next_token_ids,
+            req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+            b_req_idx=b_req_idx,
+            b_mtp_index=b_mtp_index,
+            b_has_out=b_has_out,
+        )
+        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
+            b_req_idx=b_req_idx,
+            next_token_ids=next_token_ids,
+            mask=b_has_out,
+        )
+        next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
+            next_token_ids, next_token_logprobs
+        )
+        return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu
+
     def _dp_all_gather_prefill_and_decode_req_num(
         self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]
     ) -> Tuple[np.ndarray, np.ndarray]:
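
For reference, scatter_token (the Triton kernel imported above from gather_token_id) conceptually writes each sampled token into a per-request "next token" table, with an optional mask so that prefill chunks which do not emit an output token leave their slot untouched. A minimal pure-PyTorch sketch of that behaviour, assuming req_to_next_token_ids is an int64 table of shape [max_reqs, mtp_step + 1] (the shape and masking rule are illustrative assumptions, not the actual kernel):

    import torch

    def scatter_token_reference(next_token_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_has_out=None):
        # Each request writes its sampled token into the slot selected by its MTP index.
        if b_has_out is None:
            req_to_next_token_ids[b_req_idx, b_mtp_index] = next_token_ids
        else:
            # Prefill: only requests whose chunk actually produced an output token are updated.
            req_to_next_token_ids[b_req_idx[b_has_out], b_mtp_index[b_has_out]] = next_token_ids[b_has_out]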

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 27 additions & 40 deletions

@@ -102,8 +102,15 @@ def prefill_normal(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-            _, next_token_ids_cpu, next_token_logprobs_cpu, _ = self._main_model_forward(
-                model_input, run_reqs, self.prefill_mask_func
+            model_output = self.model.forward(model_input)
+            next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                logits=model_output.logits,
+                b_req_idx=model_input.b_req_idx,
+                b_mtp_index=model_input.b_mtp_index,
+                run_reqs=run_reqs,
+                is_prefill=True,
+                b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu,
+                mask_func=self.prefill_mask_func,
             )
             sync_event = torch.cuda.Event()
             sync_event.record()
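
In the prefill path, b_prefill_has_output_cpu is turned into a GPU bool mask through the pinned-memory manager so the host-to-device copy does not block. A rough plain-torch equivalent of that step (g_pin_mem_manager.gen_from_list is assumed to hand back a reusable pinned CPU tensor; the flag list here is made up):

    import torch

    # One flag per running request: True if this prefill chunk is the last one
    # and therefore produces a real next token.
    b_prefill_has_output_cpu = [True, False, True]

    # Stage in pinned host memory, then copy to the GPU without blocking the host.
    b_has_out_pinned = torch.tensor(b_prefill_has_output_cpu, dtype=torch.bool, pin_memory=True)
    b_has_out = b_has_out_pinned.cuda(non_blocking=True)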

@@ -133,8 +140,14 @@ def decode_normal(
     ):
         model_input, run_reqs = prepare_decode_inputs(decode_reqs)
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-            _, next_token_ids_cpu, next_token_logprobs_cpu, _ = self._main_model_forward(
-                model_input, run_reqs, self.decode_mask_func
+            model_output = self.model.forward(model_input)
+            next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                logits=model_output.logits,
+                b_req_idx=model_input.b_req_idx,
+                b_mtp_index=model_input.b_mtp_index,
+                run_reqs=run_reqs,
+                is_prefill=False,
+                mask_func=self.decode_mask_func,
             )
             sync_event = torch.cuda.Event()
             sync_event.record()
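
Both prefill_normal and decode_normal queue the forward pass, sampling, and the async copies on the overlap stream, then record a CUDA event so later code can wait before touching the pinned CPU buffers. A standalone sketch of that pattern (names are illustrative; get_overlap_stream() is assumed to return a secondary torch.cuda.Stream):

    import torch

    overlap_stream = torch.cuda.Stream()  # stands in for g_infer_context.get_overlap_stream()
    with torch.cuda.stream(overlap_stream):
        logits = torch.randn(4, 32000, device="cuda")     # stands in for self.model.forward(...)
        next_token_ids = torch.argmax(logits, dim=-1)     # stands in for sampling
        ids_cpu = torch.empty(next_token_ids.shape, dtype=next_token_ids.dtype, pin_memory=True)
        ids_cpu.copy_(next_token_ids, non_blocking=True)  # async device-to-host copy into pinned memory
        sync_event = torch.cuda.Event()
        sync_event.record()                               # marks the end of the work queued above

    sync_event.synchronize()  # host waits here before reading ids_cpu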

@@ -167,8 +180,15 @@ def prefill_mtp(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-            next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu, model_output = self._main_model_forward(
-                model_input, run_reqs, self.prefill_mask_func
+            model_output = self.model.forward(model_input)
+            next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token(
+                logits=model_output.logits,
+                b_req_idx=model_input.b_req_idx,
+                b_mtp_index=model_input.b_mtp_index,
+                run_reqs=run_reqs,
+                is_prefill=True,
+                b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu,
+                mask_func=self.prefill_mask_func,
             )
             # mtp kv fill
             self._draft_prefill_forward(model_input, model_output, self.prefill_mtp_step, next_token_ids)
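
In prefill_mtp, calling self.model.forward directly in the caller (rather than inside the removed _main_model_forward helper) keeps model_output in scope, so the draft (MTP) model can fill its KV cache from the main model's outputs. A schematic of the resulting control flow, with a stand-in function name and no lightllm internals:

    def prefill_mtp_flow(backend, model_input, run_reqs):
        # 1. Main model forward pass.
        model_output = backend.model.forward(model_input)
        # 2. Sample and scatter next tokens via the shared helper now living in ModeBackend.
        next_token_ids, ids_cpu, logprobs_cpu = backend._sample_and_scatter_token(
            logits=model_output.logits,
            b_req_idx=model_input.b_req_idx,
            b_mtp_index=model_input.b_mtp_index,
            run_reqs=run_reqs,
            is_prefill=True,
            b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu,
            mask_func=backend.prefill_mask_func,
        )
        # 3. Reuse model_output to fill the draft model's KV cache.
        backend._draft_prefill_forward(model_input, model_output, backend.prefill_mtp_step, next_token_ids)
        return ids_cpu, logprobs_cpu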

@@ -201,7 +221,7 @@ def decode_mtp(
         decode_reqs: List[InferReq],
     ):
         if self.is_mtp_eagle:
-            draft_model_input, _, eagle_mem_indexes_cpu = prepare_eagle_decode_inputs(decode_reqs, self.mtp_step)
+            draft_model_input, eagle_mem_indexes_cpu = prepare_eagle_decode_inputs(decode_reqs, self.mtp_step)
             self._decode_mtp_common(
                 event_pack=event_pack,
                 decode_reqs=decode_reqs,

@@ -218,39 +238,6 @@ def decode_mtp(
             )
             return

-    def _main_model_forward(
-        self, model_input: ModelInput, run_reqs: List[InferReq], mask_func: Optional[Callable] = None
-    ):
-        model_output = self.model.forward(model_input)
-        logits = model_output.logits
-
-        if mask_func is not None:
-            mask_func(run_reqs, logits)
-
-        next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
-        b_has_out = None
-        if model_input.is_prefill:
-            b_has_out = g_pin_mem_manager.gen_from_list(
-                key="b_has_out", data=model_input.b_prefill_has_output_cpu, dtype=torch.bool
-            ).cuda(non_blocking=True)
-
-        scatter_token(
-            next_token_ids=next_token_ids,
-            req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
-            b_req_idx=model_input.b_req_idx,
-            b_mtp_index=model_input.b_mtp_index,
-            b_has_out=b_has_out,
-        )
-        g_infer_context.req_sampling_manager.update_reqs_out_token_counter_gpu(
-            b_req_idx=model_input.b_req_idx,
-            next_token_ids=next_token_ids,
-            mask=b_has_out,
-        )
-        next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
-            next_token_ids, next_token_logprobs
-        )
-        return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu, model_output
-
     def _draft_prefill_forward(
         self, model_input: ModelInput, model_output: ModelOutput, mtp_step: int, next_token_ids: torch.Tensor
     ):
