Commit ac47e1f

mtp overlap (draft)

1 parent 8705d0a

2 files changed: +108 -61 lines


lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 17 additions & 1 deletion
@@ -20,6 +20,7 @@
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from lightllm.server.core.objs.io_objs import AbortedReqCmd
 from lightllm.server.router.model_infer.infer_batch import g_infer_context
+from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager
 from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size
 from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
 from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
@@ -250,6 +251,21 @@ def init_mtp_draft_model(self, main_kvargs: dict):
             self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
         return
 
+    def _save_next_token_ids_and_logprobs(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
+        """
+        Copy the next token ids and logprobs into pinned memory (non-blocking) and return the
+        pinned CPU tensors, so that post_handle can read the correct output results once the
+        copies have been synchronized.
+        """
+        next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
+            "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype
+        )
+        next_token_logprobs_cpu = g_pin_mem_manager.alloc_pin_tensor(
+            "next_token_logprobs", next_token_logprobs.shape[0], next_token_logprobs.dtype
+        )
+        next_token_ids_cpu.copy_(next_token_ids, non_blocking=True)
+        next_token_logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
+        return next_token_ids_cpu, next_token_logprobs_cpu
+
     def _try_read_new_reqs(self):
         if self.is_multinode_tp:
             self._try_read_new_reqs_multinode_tp()
@@ -568,7 +584,7 @@ def _gen_argmax_token_ids(self, model_output: ModelOutput):
         logits = model_output.logits
         probs = torch.softmax(logits, dim=-1)
         draft_next_token_ids_gpu = torch.argmax(probs, dim=-1)
-        return draft_next_token_ids_gpu, draft_next_token_ids_gpu.detach().cpu().numpy()
+        return draft_next_token_ids_gpu
 
     def _update_reqs_mtp_gen_token_ids(self, reqs: List[InferReq], mtp_draft_next_token_ids: np.ndarray):
         for req, token_id in zip(reqs, mtp_draft_next_token_ids):
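The new _save_next_token_ids_and_logprobs helper stages the sampled ids and logprobs in pinned host memory so the device-to-host copies can be issued asynchronously and only waited on, via a recorded CUDA event, right before post handling reads them. Below is a minimal standalone sketch of that pattern, using a plain torch.empty(..., pin_memory=True) allocation in place of g_pin_mem_manager.alloc_pin_tensor; the helper name stage_to_pinned and the shapes are illustrative only, not part of the commit.

    import torch

    def stage_to_pinned(next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
        # Page-locked host buffers let the async D2H copies overlap with other GPU work.
        ids_cpu = torch.empty(next_token_ids.shape[0], dtype=next_token_ids.dtype, pin_memory=True)
        logprobs_cpu = torch.empty(next_token_logprobs.shape[0], dtype=next_token_logprobs.dtype, pin_memory=True)
        # Non-blocking copies are enqueued on the current CUDA stream.
        ids_cpu.copy_(next_token_ids, non_blocking=True)
        logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
        # An event recorded after the copies marks the point the CPU must wait for
        # before it can safely read the pinned buffers.
        sync_event = torch.cuda.Event()
        sync_event.record()
        return ids_cpu, logprobs_cpu, sync_event

    if torch.cuda.is_available():
        ids = torch.randint(0, 32000, (8,), device="cuda")
        logprobs = torch.randn(8, device="cuda")
        ids_cpu, logprobs_cpu, evt = stage_to_pinned(ids, logprobs)
        # ...unrelated CPU work can run here while the copies are in flight...
        evt.synchronize()  # after this, ids_cpu / logprobs_cpu hold valid data
        print(ids_cpu.tolist())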

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py

Lines changed: 91 additions & 60 deletions
@@ -104,14 +104,9 @@ def prefill_normal(
             model_input.b_req_idx,
             model_input.b_mtp_index,
         )
-        next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype
+        next_token_ids_cpu, next_token_logprobs_cpu = self._save_next_token_ids_and_logprobs(
+            next_token_ids, next_token_logprobs
         )
-        next_token_logprobs_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_logprobs", next_token_logprobs.shape[0], next_token_logprobs.dtype
-        )
-        next_token_ids_cpu.copy_(next_token_ids, non_blocking=True)
-        next_token_logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
         sync_event = torch.cuda.Event()
         sync_event.record()
 
@@ -152,14 +147,9 @@ def decode_normal(
             model_input.b_req_idx,
             model_input.b_mtp_index,
         )
-        next_token_ids_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_ids", next_token_ids.shape[0], next_token_ids.dtype
-        )
-        next_token_logprobs_cpu = g_pin_mem_manager.alloc_pin_tensor(
-            "next_token_logprobs", next_token_logprobs.shape[0], next_token_logprobs.dtype
+        next_token_ids_cpu, next_token_logprobs_cpu = self._save_next_token_ids_and_logprobs(
+            next_token_ids, next_token_logprobs
         )
-        next_token_ids_cpu.copy_(next_token_ids, non_blocking=True)
-        next_token_logprobs_cpu.copy_(next_token_logprobs, non_blocking=True)
         sync_event = torch.cuda.Event()
         sync_event.record()
 
@@ -190,13 +180,44 @@ def prefill_mtp(
         model_input, run_reqs = prepare_prefill_inputs(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
         )
-        model_output = self.model.forward(model_input)
+        with torch.cuda.stream(g_infer_context.get_overlap_stream()):
+            model_output = self.model.forward(model_input)
+            next_token_ids, next_token_logprobs = sample(model_output.logits, run_reqs, self.eos_id)
+
+            scatter_token(
+                next_token_ids,
+                self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+                model_input.b_req_idx,
+                model_input.b_mtp_index,
+            )
+            next_token_ids_cpu, next_token_logprobs_cpu = self._save_next_token_ids_and_logprobs(
+                next_token_ids, next_token_logprobs
+            )
+            # mtp kv fill
+            draft_next_token_ids_gpu = next_token_ids
+            draft_model_output = model_output
+            draft_model_input = model_input
+            # spec prefill: MTP. This only fills the draft model's kv cache; the generated token_id is not used.
+            for draft_model_idx in range(self.mtp_step):
+                draft_model_input = prepare_mtp_prefill_inputs(
+                    model_input=draft_model_input,
+                    b_next_token_ids=draft_next_token_ids_gpu,
+                    deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens,
+                )
+                draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input)
+                draft_next_token_ids_gpu = self._gen_argmax_token_ids(draft_model_output)
 
-        next_token_ids_gpu, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
-        next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy()
-        next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy()
+            sync_event = torch.cuda.Event()
+            sync_event.record()
 
+        # Stage 2
+        event_pack.notify_post_handle_and_wait_pre_post_handle()
         update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill)
+
+        # Stage 3
+        event_pack.notify_forward_and_wait_post_handle()
+        sync_event.synchronize()
+
         self._post_handle(
             run_reqs=run_reqs,
             next_token_ids=next_token_ids_cpu,
@@ -205,20 +226,8 @@ def prefill_mtp(
             extra_post_req_handle_func=self.extra_post_req_handle_func,
         )
 
-        # mtp kv fill
-        draft_next_token_ids_gpu = next_token_ids_gpu
-        draft_model_output = model_output
-        draft_model_input = model_input
-        # spec prefill: MTP. This only fills the draft model's kv cache; the generated token_id is not used.
-        for draft_model_idx in range(self.mtp_step):
-            draft_model_input = prepare_mtp_prefill_inputs(
-                model_input=draft_model_input,
-                b_next_token_ids=draft_next_token_ids_gpu,
-                deepseekv3_mtp_draft_input_hiddens=draft_model_output.deepseekv3_mtp_main_output_hiddens,
-            )
-
-            draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input)
-            draft_next_token_ids_gpu, draft_next_token_ids_cpu = self._gen_argmax_token_ids(draft_model_output)
+        # Stage 4
+        event_pack.notify_pre_post_handle()
         return
 
     def decode_mtp(
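In prefill_mtp (and decode_mtp below), the main forward pass, sampling, pinned-memory staging, and the draft-model loop are all enqueued on a side CUDA stream obtained from g_infer_context.get_overlap_stream(), so the host thread can move on to the event_pack stage notifications and _pre_post_handle while the GPU is still busy. The following is a reduced sketch of that overlap shape under stated assumptions: the stream, tensors, and CPU work are placeholders, not the backend's actual objects or API.

    import torch

    def overlapped_step():
        overlap_stream = torch.cuda.Stream()  # stand-in for g_infer_context.get_overlap_stream()
        x = torch.randn(4096, 4096, device="cuda")
        out_cpu = torch.empty(4096, 4096, pin_memory=True)

        # Make the side stream wait for work already queued on the default stream.
        overlap_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(overlap_stream):
            y = x @ x                            # placeholder for model.forward + sampling
            out_cpu.copy_(y, non_blocking=True)  # placeholder for the pinned-memory staging
            sync_event = torch.cuda.Event()
            sync_event.record()                  # recorded on the overlap stream

        # Host-side work overlaps with the GPU work still running on the side stream
        # (analogous to the event_pack handoff and _pre_post_handle).
        cpu_side = sum(range(100_000))

        sync_event.synchronize()                 # wait before reading the pinned buffer
        return out_cpu, cpu_side

    if torch.cuda.is_available():
        out, _ = overlapped_step()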
@@ -227,47 +236,69 @@
         decode_reqs: List[InferReq],
     ):
         model_input, run_reqs = prepare_decode_inputs(decode_reqs)
-        model_output = self.model.forward(model_input)
+        with torch.cuda.stream(g_infer_context.get_overlap_stream()):
+            model_output = self.model.forward(model_input)
 
-        next_token_ids_gpu, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
-        next_token_ids_cpu = next_token_ids_gpu.detach().cpu().numpy()
-        next_token_logprobs_cpu = torch.log(next_token_probs).detach().cpu().numpy()
+            next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
+            scatter_token(
+                next_token_ids,
+                self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
+                model_input.b_req_idx,
+                model_input.b_mtp_index,
+            )
+            next_token_ids_cpu, next_token_logprobs_cpu = self._save_next_token_ids_and_logprobs(
+                next_token_ids, next_token_probs
+            )
 
-        # verify
-        mem_indexes_cpu = model_input.mem_indexes.detach().cpu().numpy()
-        verify_ok_reqs, verify_ok_req_indexes, verify_ok_req_last_indexes, need_free_mem_indexes = self._verify_mtp(
-            run_reqs, next_token_ids_cpu, mem_indexes_cpu
-        )
+            # verify
+            mem_indexes_cpu = model_input.mem_indexes.detach().cpu().numpy()
+            verify_ok_reqs, verify_ok_req_indexes, verify_ok_req_last_indexes, need_free_mem_indexes = self._verify_mtp(
+                run_reqs, next_token_ids_cpu, mem_indexes_cpu
+            )
+
+            # share some inference info with the main model
+            draft_model_input = model_input
+            draft_model_output = model_output
+            draft_next_token_ids = next_token_ids
+            # process the draft model output
+            for draft_model_idx in range(self.mtp_step):
+
+                draft_model_input.input_ids = draft_next_token_ids
+                draft_model_input.deepseekv3_mtp_draft_input_hiddens = (
+                    draft_model_output.deepseekv3_mtp_main_output_hiddens
+                )
+                # spec decode: MTP
+                draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input)
+                draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output)
+
+            unique_reqs = [run_reqs[index] for index in verify_ok_req_last_indexes]
+            draft_next_token_ids_cpu = draft_next_token_ids.detach().cpu().numpy()
+            self._update_reqs_mtp_gen_token_ids(
+                reqs=unique_reqs, mtp_draft_next_token_ids=draft_next_token_ids_cpu[verify_ok_req_last_indexes]
+            )
+            sync_event = torch.cuda.Event()
+            sync_event.record()
 
+        # Stage 2
+        event_pack.notify_post_handle_and_wait_pre_post_handle()
         update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)
+
+        # Stage 3
+        event_pack.notify_forward_and_wait_post_handle()
+        sync_event.synchronize()
+
         self._post_handle(
             run_reqs=verify_ok_reqs,
             next_token_ids=next_token_ids_cpu[verify_ok_req_indexes],
             next_token_logprobs=next_token_logprobs_cpu[verify_ok_req_indexes],
             run_reqs_update_packs=update_packs,
             extra_post_req_handle_func=self.extra_post_req_handle_func,
         )
-
-        # share some inference info with the main model
-        draft_model_input = model_input
-        draft_model_output = model_output
-        draft_next_token_ids = next_token_ids_gpu
-        # process the draft model output
-        for draft_model_idx in range(self.mtp_step):
-
-            draft_model_input.input_ids = draft_next_token_ids
-            draft_model_input.deepseekv3_mtp_draft_input_hiddens = draft_model_output.deepseekv3_mtp_main_output_hiddens
-            # spec decode: MTP
-            draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input)
-            draft_next_token_ids, draft_next_token_ids_cpu = self._gen_argmax_token_ids(draft_model_output)
-
-        unique_reqs = [run_reqs[index] for index in verify_ok_req_last_indexes]
-        self._update_reqs_mtp_gen_token_ids(
-            reqs=unique_reqs, mtp_draft_next_token_ids=draft_next_token_ids_cpu[verify_ok_req_last_indexes]
-        )
-
         if need_free_mem_indexes:
             g_infer_state_lock.acquire()
             g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes)
             g_infer_state_lock.release()
+
+        # Stage 4
+        event_pack.notify_pre_post_handle()
         return
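The verification step that decode_mtp keeps from the non-overlap path accepts a request's speculative positions only where the main model reproduced the draft (MTP) tokens proposed in the previous step; rejected positions are rolled back and their KV slots are freed via need_free_mem_indexes. A toy illustration of that accept/reject rule follows; it is a conceptual analogue only, independent of the actual _verify_mtp signature, which this diff does not show.

    import numpy as np

    def verify_speculative_tokens(draft_tokens: np.ndarray, main_tokens: np.ndarray) -> int:
        """Accept the longest prefix of draft tokens that the main model reproduced.

        draft_tokens: tokens proposed speculatively by the draft (MTP) model last step.
        main_tokens:  tokens the main model actually produced at the same positions.
        Returns the number of accepted tokens; positions after the first mismatch
        must be rolled back (their KV-cache slots freed).
        """
        matches = draft_tokens == main_tokens
        if matches.all():
            return len(draft_tokens)
        return int(np.argmax(~matches))  # index of the first mismatch

    # Example: the draft guessed [5, 9, 2]; the main model produced [5, 9, 7].
    accepted = verify_speculative_tokens(np.array([5, 9, 2]), np.array([5, 9, 7]))
    assert accepted == 2  # the third speculative position is rejected and freed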
