Commit 1a0d132

multi step mtp and dynamic_prompt cache for mtp
1 parent 7ff2329 commit 1a0d132

8 files changed: +296 -98 lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 8 additions & 2 deletions
@@ -450,9 +450,12 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
         predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)

         g_cache_manager.cache_env_out()
+        is_return_hidden_states = self.spec_algo.is_mtp() or (
+            self.spec_algo.is_mtp_module() and not self.last_mtp_module
+        )
         return ModelOutput(
             logits=predict_logits,
-            hidden_states=input_embs if self.spec_algo.is_mtp() else None,
+            hidden_states=input_embs if is_return_hidden_states else None,
         )

     @final
@@ -475,9 +478,12 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo):
         predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)

         g_cache_manager.cache_env_out()
+        is_return_hidden_states = self.spec_algo.is_mtp() or (
+            self.spec_algo.is_mtp_module() and not self.last_mtp_module
+        )
         return ModelOutput(
             logits=predict_logits,
-            hidden_states=input_embs if self.spec_algo.is_mtp() else None,
+            hidden_states=input_embs if is_return_hidden_states else None,
         )

     @final
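
In words, the change in both hunks makes a forward pass return hidden states not only when the model itself is running MTP speculation, but also when it is an MTP draft module that is not the last one in the chain, so the next draft module can consume them. A minimal sketch of that decision, with the spec_algo object stubbed out (the stub class is an assumption; only the boolean logic mirrors the diff):

# Sketch, not repo code: when should forward() attach hidden_states to ModelOutput?
class _SpecAlgoStub:
    """Stand-in for the real spec_algo object referenced in basemodel.py."""

    def __init__(self, main_model_uses_mtp: bool, is_draft_module: bool):
        self.main_model_uses_mtp = main_model_uses_mtp
        self.is_draft_module = is_draft_module

    def is_mtp(self) -> bool:
        return self.main_model_uses_mtp

    def is_mtp_module(self) -> bool:
        return self.is_draft_module


def should_return_hidden_states(spec_algo, last_mtp_module: bool) -> bool:
    # Main model in MTP mode: the first draft module needs its hidden states.
    # Draft module: hidden states are only needed if another draft module runs after it.
    return spec_algo.is_mtp() or (spec_algo.is_mtp_module() and not last_mtp_module)


# The last draft module in a multi-step chain can skip the extra copy:
assert should_return_hidden_states(_SpecAlgoStub(False, True), last_mtp_module=True) is False
assert should_return_hidden_states(_SpecAlgoStub(False, True), last_mtp_module=False) is True
assert should_return_hidden_states(_SpecAlgoStub(True, False), last_mtp_module=False) is True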

lightllm/models/deepseek_mtp/model.py

Lines changed: 1 addition & 1 deletion
@@ -22,12 +22,12 @@ class Deepseek3MTPModel(Deepseek2TpPartModel):
     def __init__(self, kvargs):
         self.main_model = kvargs.pop("main_model")
         self.req_manager = self.main_model.req_manager
+        self.last_mtp_module = kvargs.pop("last_mtp_module", False)
         super().__init__(kvargs)

     def _init_req_manager(self):
         # draft model shares the same req_manager with the main model
         if hasattr(self, "req_manager"):
-            print("SKIP INIT REQ!!!!!!!!")
             return
         create_max_seq_len = 0
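
The new last_mtp_module flag lets the caller build a chain of draft modules for multi-step MTP and mark only the final one, which (per the basemodel.py change above) is the only module allowed to drop its hidden states. A hedged sketch of how such a chain could be assembled; the loop, kvargs contents, and num_mtp_modules are illustrative assumptions, and only the constructor signature and the two kvargs keys come from the diff:

from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel


def build_mtp_chain(main_model, base_kvargs: dict, num_mtp_modules: int):
    """Construct num_mtp_modules draft modules that share the main model's req_manager."""
    draft_models = []
    for idx in range(num_mtp_modules):
        kvargs = dict(base_kvargs)                       # per-module copy of the base kvargs
        kvargs["main_model"] = main_model                # popped in Deepseek3MTPModel.__init__
        kvargs["last_mtp_module"] = idx == num_mtp_modules - 1
        draft_models.append(Deepseek3MTPModel(kvargs))
    return draft_models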

lightllm/server/core/objs/req.py

Lines changed: 3 additions & 0 deletions
@@ -94,6 +94,8 @@ class Req(ctypes.Structure):
         ("reward_score", ctypes.c_float),
         # cumulative log probability of the request's reply
         ("cumlogprob", ctypes.c_float),
+        # accepted draft-token length from the mtp draft model
+        ("mtp_accepted_len", ctypes.c_int),
     ]

     def get_str(self):
@@ -145,6 +147,7 @@ def init(
         self.create_prompt_ids_shm_array()
         self.chunked_prefill_size = chunked_prefill_size
         self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
+        self.mtp_accepted_len = 0

         self.post_init()
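
Since Req is a ctypes.Structure placed in shared memory, the new counter is a plain c_int that the inference side can increment in place and the HTTP server can read back when the request finishes. A minimal, self-contained sketch of that pattern (the struct below is a two-field stand-in, not the full Req layout):

import ctypes


class _MiniReq(ctypes.Structure):
    # Stand-in with only the fields relevant to this commit.
    _fields_ = [
        ("cumlogprob", ctypes.c_float),
        ("mtp_accepted_len", ctypes.c_int),  # accepted draft-token count, starts at 0
    ]


req = _MiniReq()
req.mtp_accepted_len = 0       # mirrors self.mtp_accepted_len = 0 in Req.init
req.mtp_accepted_len += 3      # e.g. a decode step that accepted 3 draft tokens
print(req.mtp_accepted_len)    # -> 3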

lightllm/server/httpserver/manager.py

Lines changed: 4 additions & 0 deletions
@@ -540,6 +540,8 @@ async def _wait_to_token_package(
             x_request_id = request.headers.get("X-Request-Id", "") if request is not None else ""
             x_session_id = request.headers.get("X-Session-Id", "") if request is not None else ""
             prompt_cache_ratio = prompt_cache_len / prompt_tokens
+
+            avg_token_per_step = out_token_counter / (out_token_counter - metadata["mtp_accepted_len"])
             format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S")
             logger.info(
                 f"X-Request-Id:{x_request_id} "
@@ -550,6 +552,7 @@
                 f"prompt_token_num:{prompt_tokens} "
                 f"prompt_cache_len:{prompt_cache_len} "
                 f"prompt_cache_ratio:{prompt_cache_ratio} "
+                f"avg_token_per_step:{avg_token_per_step} "
             )
             if group_request_id < 0:
                 # health probe request; do not log or record metrics
@@ -652,6 +655,7 @@ async def handle_loop(self):
                     "special": special,
                     "count_output_tokens": count_output_tokens,
                     "prompt_cache_len": req.prompt_cache_len,
+                    "mtp_accepted_len": req.mtp_accepted_len,
                 }
                 if self.args.return_all_prompt_logprobs:
                     metadata.update(req.get_all_prompt_metadata())
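
avg_token_per_step reports how many output tokens each main-model decode step produced on average: every step emits one main-model token plus however many draft tokens were accepted, so out_token_counter - metadata["mtp_accepted_len"] counts exactly the number of decode steps. A quick worked check of the formula with made-up numbers:

# Hypothetical values, only to illustrate the metric added in manager.py.
out_token_counter = 120   # tokens streamed back to the client
mtp_accepted_len = 40     # of those, 40 were accepted MTP draft tokens

decode_steps = out_token_counter - mtp_accepted_len       # 80 main-model forward steps
avg_token_per_step = out_token_counter / decode_steps     # 120 / 80
print(avg_token_per_step)                                 # -> 1.5 tokens per step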

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 11 additions & 0 deletions
@@ -257,6 +257,7 @@ def __init__(
         self.vocab_size = vocab_size
         self.initialized = False
         self.paused = False
+        self.cur_accepted_len = 0  # for mtp forward

     def init_all(self):
         if self.initialized is False:
@@ -319,6 +320,13 @@ def get_chuncked_input_token_ids(self):
         chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
         return self.shm_req.shm_prompt_ids.arr[0:chunked_end]

+    def get_chunked_input_token_ids_shift(self, shift=-1):
+        input_ids = self.get_input_token_ids()
+        shift_input_ids = np.roll(input_ids, shift)
+        chunked_start = self.cur_kv_len
+        chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
+        return shift_input_ids[shift:chunked_end]
+
     def get_chuncked_input_token_len(self):
         chunked_start = self.cur_kv_len
         chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
@@ -330,6 +338,9 @@ def set_next_gen_token_id(self, next_token_id: int, logprob: float):
         self.shm_req.shm_logprobs.arr[index] = logprob
         return

+    def set_total_accepted_len(self):
+        self.shm_req.mtp_accepted_len += self.cur_accepted_len
+
     def get_last_gen_token(self):
         return self.shm_req.shm_prompt_ids.arr[self.shm_req.input_len + self.cur_output_len - 1]
Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+import torch
+import numpy as np
+from typing import List, Tuple
+from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
+from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
+from lightllm.utils.log_utils import init_logger
+from lightllm.server.router.model_infer.infer_batch import g_infer_context
+from lightllm.server.router.model_infer.mode_backend.generic_pre_process import (
+    prepare_prefill_inputs,
+)
+from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import (
+    prepare_mtp_prefill_inputs,
+    prepare_draft_main_model_decode_inputs,
+)
+from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
+import os
+from lightllm.common.basemodel.infer_lock import g_infer_state_lock
+from lightllm.server.router.model_infer.infer_batch import InferReq
+from lightllm.server.router.model_infer.mode_backend.continues_batch.impl_mtp import ContinuesBatchWithMTPBackend
+import copy
+from lightllm.utils.dist_utils import device0_print
+
+
+logger = init_logger(__name__)
+
+
+class ChunkedPrefillWithMTPBackend(ContinuesBatchWithMTPBackend):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def decode(self):
+        uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
+            g_infer_context.infer_req_ids
+        )
+
+        if aborted_reqs:
+            g_infer_context.filter_reqs(aborted_reqs)
+
+        if prefill_reqs:
+            model_input, run_reqs = prepare_prefill_inputs(
+                prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
+            )
+            model_output = self.model.forward(model_input)
+
+            self._overlap_req_init_and_filter(
+                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
+            )
+
+            next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
+            next_token_ids = next_token_ids.detach().cpu().numpy()
+            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+            self._post_handle(
+                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
+            )
+            # spec prefill: MTP
+            last_input_ids_cpu = None
+            draft_model_input = model_input
+            last_hidden_states = model_output.hidden_states
+            for draft_model_idx in range(self.spec_step):
+                device0_print(f"main {draft_model_input}")
+                draft_model_input, last_input_ids_cpu = prepare_mtp_prefill_inputs(
+                    prefill_reqs, model_input, last_hidden_states, next_token_ids, last_input_ids_cpu
+                )
+                device0_print(f"draft_model_input {draft_model_input}")
+                draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input)
+                draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id)
+                draft_next_token_ids = draft_next_token_ids.detach().cpu().numpy()
+
+                last_hidden_states = draft_model_output.hidden_states
+                next_token_ids = draft_next_token_ids
+                self._save_draft_token_ids(draft_next_token_ids, run_reqs, draft_model_idx)
+
+        if decode_reqs:
+            model_input, run_reqs, mem_indexes_cpu = prepare_draft_main_model_decode_inputs(
+                decode_reqs, self.draft_token_id_map
+            )
+            model_output = self.model.forward(model_input)
+            assert model_output.logits.shape[0] % self.spec_stride == 0
+
+            self._overlap_req_init_and_filter(
+                uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
+            )
+
+            next_token_ids_cuda, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
+            next_token_ids = next_token_ids_cuda.detach().cpu().numpy()
+            next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
+
+            # verify
+            accepted_reqs, accepted_index, need_free_mem_indexes = self.verify(
+                next_token_ids, run_reqs, mem_indexes_cpu
+            )
+            self._post_handle(
+                accepted_reqs,
+                next_token_ids[accepted_index],
+                next_token_logprobs[accepted_index],
+                is_chuncked_mode=False,
+                do_filter_finished_reqs=False,
+            )
+            # share some inference info with the main model
+            draft_model_input = model_input
+            draft_model_input.input_ids = next_token_ids_cuda
+            draft_model_input.hidden_states = model_output.hidden_states
+            # process the draft model output
+            for draft_model_idx in range(self.spec_step):
+                # spec decode: MTP
+                draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input)
+                draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id)
+                # prepare inputs for the next draft model
+                draft_model_input.input_ids = draft_next_token_ids
+                draft_model_input.hidden_states = draft_model_output.hidden_states
+                draft_next_token_ids_numpy = draft_next_token_ids.detach().cpu().numpy()
+                self._save_draft_token_ids(draft_next_token_ids_numpy, run_reqs, draft_model_idx)
+
+            if need_free_mem_indexes:
+                g_infer_state_lock.acquire()
+                g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes)
+                g_infer_state_lock.release()
+
+        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
+        return
+
+    def verify(self, next_token_ids, run_reqs, draft_mem_indexes):
+        accepted_reqs = []
+        accepted_index = []
+        need_free_mem_indexes = []
+        assert next_token_ids.shape[0] % self.spec_stride == 0
+
+        for i, req in enumerate(run_reqs):
+            # main model output
+            if i % self.spec_stride == 0:
+                accepted_reqs.append(req)
+                accepted_index.append(i)
+                continue
+            draft_model_idx = i % self.spec_stride - 1
+            if (
+                self.draft_token_id_map[req.req_idx][draft_model_idx] == next_token_ids[i - 1]
+                and req.cur_accepted_len == draft_model_idx
+            ):
+                accepted_reqs.append(req)
+                accepted_index.append(i)
+                req.cur_accepted_len += 1
+                device0_print(f"req {req.req_idx} accepted, cur_accepted_len {req.cur_accepted_len}")
+            else:
+                need_free_mem_indexes.append(draft_mem_indexes[i])
+        return accepted_reqs, accepted_index, need_free_mem_indexes
+
+    def _save_draft_token_ids(self, draft_next_token_ids, run_reqs, draft_model_idx):
+        batch_size = len(run_reqs) // self.spec_stride
+        for i in range(batch_size):
+            req = run_reqs[self.spec_stride * i]
+            self.draft_token_id_map[req.req_idx][draft_model_idx] = draft_next_token_ids[i + req.cur_accepted_len]
+            # reset the cur_accepted_len
+            if draft_model_idx == self.spec_step - 1:
+                req.cur_accepted_len = 0
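
In verify(), the decode batch is laid out in groups of spec_stride positions per request: position 0 in each group is always accepted (it carries the main model's own next token), and the draft at slot j is accepted only if the main model's output at the preceding position equals that draft and every earlier draft in the group was already accepted (req.cur_accepted_len == draft_model_idx). That makes acceptance a longest-matching-prefix rule. The standalone sketch below restates it over plain lists; the names and flat layout are simplified assumptions, not the backend's actual data structures:

def accept_prefix(main_token_ids, draft_token_ids):
    """main_token_ids[j]: what the main model emitted at draft position j;
    draft_token_ids[j]: what the draft model proposed there in the previous step.
    Drafts are accepted greedily as a prefix; acceptance stops at the first mismatch."""
    accepted = 0
    for main_tok, draft_tok in zip(main_token_ids, draft_token_ids):
        if main_tok != draft_tok:
            break
        accepted += 1
    return accepted


# e.g. spec_step = 2, so two draft tokens were speculated in the previous step:
print(accept_prefix([42, 99], [42, 99]))  # -> 2, both drafts accepted
print(accept_prefix([42, 99], [42, 7]))   # -> 1, second draft rejected
print(accept_prefix([42, 99], [5, 99]))   # -> 0, prefix broken at the first draft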
