
Commit e032bf7

fix mtp norm and fix chunked
1 parent 4f12cd4 commit e032bf7

File tree: 5 files changed (+65 -50 lines changed)

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 4 additions & 3 deletions
@@ -320,12 +320,13 @@ def get_chuncked_input_token_ids(self):
         chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
         return self.shm_req.shm_prompt_ids.arr[0:chunked_end]

-    def get_chunked_input_token_ids_shift(self, shift=-1):
+    def get_chunked_input_token_ids_shift(self, shift=1):
         input_ids = self.get_input_token_ids()
-        shift_input_ids = np.roll(input_ids, shift)
+        shift_input_ids = np.roll(input_ids, -1 * shift)
         chunked_start = self.cur_kv_len
         chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
-        return shift_input_ids[shift:chunked_end]
+        is_last_chunked = chunked_end == self.get_cur_total_len() + shift
+        return shift_input_ids[0:chunked_end], is_last_chunked

     def get_chuncked_input_token_len(self):
         chunked_start = self.cur_kv_len
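
For reference, a minimal standalone sketch of the shifted-chunk slicing that the reworked get_chunked_input_token_ids_shift performs (toy values only; prompt_ids, cur_kv_len, and chunked_prefill_size below are illustrative stand-ins for the real request state):

import numpy as np

# Toy stand-ins for per-request state (illustrative values, not from the repo).
prompt_ids = np.array([11, 12, 13, 14, 15, 16, 17, 18], dtype=np.int64)
cur_kv_len = 0            # kv cache filled so far
chunked_prefill_size = 4  # tokens handled per prefill chunk
total_len = len(prompt_ids)

def chunked_input_token_ids_shift(shift=1):
    # Mirrors the new method: roll left by `shift`, slice up to the chunk end,
    # and report whether this chunk is flagged as the request's last chunk.
    shift_input_ids = np.roll(prompt_ids, -1 * shift)
    chunked_start = cur_kv_len
    chunked_end = min(total_len, chunked_start + chunked_prefill_size)
    is_last_chunked = chunked_end == total_len + shift
    return shift_input_ids[0:chunked_end], is_last_chunked

chunk, is_last = chunked_input_token_ids_shift(shift=1)
print(chunk)    # [12 13 14 15] -- the prompt shifted left by one, first chunk only
print(is_last)  # False under these toy values -- more chunks remain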

lightllm/server/router/model_infer/mode_backend/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@
 from .continues_batch.pd_mode.prefill_node_impl.prefill_impl_for_dp_chuncked import DPChunkedForPrefillNode
 from .continues_batch.pd_mode.decode_node_impl.decode_impl_for_dp import DPForDecodeNode
 from .continues_batch.impl_mtp import ContinuesBatchWithMTPBackend
+from .chunked_prefill.impl_mtp import ChunkedPrefillWithMTPBackend

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_mtp.py

Lines changed: 23 additions & 45 deletions
@@ -9,7 +9,7 @@
     prepare_prefill_inputs,
 )
 from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import (
-    prepare_mtp_prefill_inputs,
+    prepare_mtp_chunked_prefill_inputs,
     prepare_draft_main_model_decode_inputs,
 )
 from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample
@@ -38,7 +38,7 @@ def decode(self):

         if prefill_reqs:
             model_input, run_reqs = prepare_prefill_inputs(
-                prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
+                prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
             )
             model_output = self.model.forward(model_input)

@@ -49,27 +49,37 @@ def decode(self):
             next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-            self._post_handle(
-                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False
-            )
+            prev_step_has_output = [
+                req_obj.get_chuncked_input_token_len() == req_obj.get_cur_total_len() for req_obj in prefill_reqs
+            ]
             # spec prefill: MTP
             last_input_ids_cpu = None
             draft_model_input = model_input
             last_hidden_states = model_output.hidden_states
+            draft_next_token_ids = next_token_ids
             for draft_model_idx in range(self.spec_step):
-                device0_print(f"main {draft_model_input}")
-                draft_model_input, last_input_ids_cpu = prepare_mtp_prefill_inputs(
-                    prefill_reqs, model_input, last_hidden_states, next_token_ids, last_input_ids_cpu
+
+                draft_model_input, last_input_ids_cpu, prev_step_has_output = prepare_mtp_chunked_prefill_inputs(
+                    prefill_reqs,
+                    model_input,
+                    last_hidden_states,
+                    draft_next_token_ids,
+                    draft_model_idx + 1,
+                    prev_step_has_output,
+                    last_input_ids_cpu,
                 )
-                device0_print(f"draft_model_input {draft_model_input}")
+
                 draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input)
                 draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id)
                 draft_next_token_ids = draft_next_token_ids.detach().cpu().numpy()

                 last_hidden_states = draft_model_output.hidden_states
-                next_token_ids = draft_next_token_ids
                 self._save_draft_token_ids(draft_next_token_ids, run_reqs, draft_model_idx)

+            self._post_handle(
+                run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False
+            )
+
         if decode_reqs:
             model_input, run_reqs, mem_indexes_cpu = prepare_draft_main_model_decode_inputs(
                 decode_reqs, self.draft_token_id_map
@@ -93,9 +103,11 @@ def decode(self):
                 accepted_reqs,
                 next_token_ids[accepted_index],
                 next_token_logprobs[accepted_index],
-                is_chuncked_mode=False,
+                is_chuncked_mode=True,
                 do_filter_finished_reqs=False,
             )
+            self.main_step += 1
+
             # share some inference info with the main model
             draft_model_input = model_input
             draft_model_input.input_ids = next_token_ids_cuda
@@ -118,37 +130,3 @@ def decode(self):

         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         return
-
-    def verify(self, next_token_ids, run_reqs, draft_mem_indexes):
-        accepted_reqs = []
-        accepted_index = []
-        need_free_mem_indexes = []
-        assert next_token_ids.shape[0] % self.spec_stride == 0
-
-        for i, req in enumerate(run_reqs):
-            # main model output
-            if i % self.spec_stride == 0:
-                accepted_reqs.append(req)
-                accepted_index.append(i)
-                continue
-            draft_model_idx = i % self.spec_stride - 1
-            if (
-                self.draft_token_id_map[req.req_idx][draft_model_idx] == next_token_ids[i - 1]
-                and req.cur_accepted_len == draft_model_idx
-            ):
-                accepted_reqs.append(req)
-                accepted_index.append(i)
-                req.cur_accepted_len += 1
-                device0_print(f"req {req.req_idx} accepted, cur_accepted_len {req.cur_accepted_len}")
-            else:
-                need_free_mem_indexes.append(draft_mem_indexes[i])
-        return accepted_reqs, accepted_index, need_free_mem_indexes
-
-    def _save_draft_token_ids(self, draft_next_token_ids, run_reqs, draft_model_idx):
-        batch_size = len(run_reqs) // self.spec_stride
-        for i in range(batch_size):
-            req = run_reqs[self.spec_stride * i]
-            self.draft_token_id_map[req.req_idx][draft_model_idx] = draft_next_token_ids[i + req.cur_accepted_len]
-            # reset the cur_accepted_len
-            if draft_model_idx == self.spec_step - 1:
-                req.cur_accepted_len = 0
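
A small illustrative sketch of how the new prev_step_has_output list is seeded before the draft loop; FakeReq and the lengths below are invented stand-ins for InferReq, and the method spelling follows the repository's own names:

from dataclasses import dataclass

@dataclass
class FakeReq:
    # Invented stand-in for InferReq: only the two lengths the check needs.
    chunked_input_token_len: int  # tokens covered once this prefill chunk is done
    cur_total_len: int            # full prompt (+ generated) length

    def get_chuncked_input_token_len(self):
        return self.chunked_input_token_len

    def get_cur_total_len(self):
        return self.cur_total_len

prefill_reqs = [
    FakeReq(chunked_input_token_len=8, cur_total_len=8),   # last chunk: main model emits a usable token
    FakeReq(chunked_input_token_len=8, cur_total_len=20),  # mid-prompt chunk: no real output yet
]

# Same predicate as in the decode() change above: the main model only has a
# usable next token for requests whose current chunk reaches the prompt end.
prev_step_has_output = [
    req.get_chuncked_input_token_len() == req.get_cur_total_len() for req in prefill_reqs
]
print(prev_step_has_output)  # [True, False]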

lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py

Lines changed: 32 additions & 1 deletion
@@ -15,18 +15,49 @@ def prepare_mtp_prefill_inputs(
     for i, req in enumerate(req_objs):
         if last_input_ids_cpu is None:
             input_token_ids = req.get_input_token_ids()
+        else:
+            input_token_ids = last_input_ids_cpu[i]
+            input_token_ids = np.roll(input_token_ids, -1)
+            input_token_ids[-1] = tgt_input_ids[i]
+        input_ids.append(input_token_ids[req.cur_kv_len :])
+    input_ids_cpu = input_ids
+    input_ids = np.concatenate(input_ids, dtype=np.int64)
+    input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda")
+    model_input.input_ids = input_ids
+    # mtp embedding
+    model_input.hidden_states = last_hidden_states
+    return model_input, input_ids_cpu
+
+
+def prepare_mtp_chunked_prefill_inputs(
+    req_objs: List[InferReq],
+    model_input: ModelInput,
+    last_hidden_states,
+    tgt_input_ids,
+    shift,
+    prev_step_has_output,
+    last_input_ids_cpu=None,
+):
+    input_ids = []
+    for i, req in enumerate(req_objs):
+        if last_input_ids_cpu is None or not prev_step_has_output[i]:
+            input_token_ids, is_last_chunked = req.get_chunked_input_token_ids_shift(shift)
+            if prev_step_has_output[i]:
+                input_token_ids[-1] = tgt_input_ids[i]
+            prev_step_has_output[i] = is_last_chunked
         else:
             input_token_ids = last_input_ids_cpu[i]
             input_token_ids = np.roll(input_token_ids, -1)
             input_token_ids[-1] = tgt_input_ids[i]
+            prev_step_has_output[i] = True
         input_ids.append(input_token_ids[req.cur_kv_len :])
     input_ids_cpu = input_ids
     input_ids = np.concatenate(input_ids, dtype=np.int64)
     input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda")
     model_input.input_ids = input_ids
     # mtp embedding
     model_input.hidden_states = last_hidden_states
-    return model_input, input_ids_cpu
+    return model_input, input_ids_cpu, prev_step_has_output


 def prepare_draft_main_model_decode_inputs(req_objs: List[InferReq], draft_token_id_map):
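
A toy illustration of the else branch above, i.e. how later draft steps reuse the previous step's per-request inputs; the token values are arbitrary:

import numpy as np

# Pretend these were the ids fed to one draft model for a single request
# (already sliced by cur_kv_len), plus the token that model just sampled.
last_step_input_ids = np.array([101, 102, 103, 104], dtype=np.int64)
newest_draft_token = 105  # arbitrary sampled id

# When the previous step produced an output, the next draft step just shifts
# the previous inputs left by one and places the newest token at the end,
# mirroring the np.roll / tgt_input_ids assignment in
# prepare_mtp_chunked_prefill_inputs.
next_step_input_ids = np.roll(last_step_input_ids, -1)
next_step_input_ids[-1] = newest_draft_token
print(next_step_input_ids)  # [102 103 104 105]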

lightllm/server/router/model_infer/model_rpc.py

Lines changed: 5 additions & 1 deletion
@@ -22,6 +22,7 @@
     ChunckedPrefillForPrefillNode,
     DPChunkedForPrefillNode,
     ContinuesBatchWithMTPBackend,
+    ChunkedPrefillWithMTPBackend,
 )
 from lightllm.server.router.model_infer.mode_backend.redundancy_expert_manager import RedundancyExpertManager
 from lightllm.server.core.objs import RpcShmParams, RpcShmResults, ShmSyncStatusArray
@@ -161,7 +162,10 @@ def init_model(self, kvargs):
             else:
                 self.backend = ContinuesBatchBackend()
         else:
-            self.backend = ChunkedPrefillBackend()
+            if kvargs.get("spec_algo", "NONE") == "MTP":
+                self.backend = ChunkedPrefillWithMTPBackend()
+            else:
+                self.backend = ChunkedPrefillBackend()

         logger.info(f"use {self.backend.__class__.__name__}")
         self.backend.init_model(kvargs)
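
A minimal sketch of the backend dispatch added in this hunk, with placeholder classes standing in for the real backends; pick_chunked_backend is a made-up helper for illustration, not a lightllm API:

# Placeholder backend classes so the dispatch can run standalone.
class ChunkedPrefillBackend: ...
class ChunkedPrefillWithMTPBackend: ...

def pick_chunked_backend(kvargs: dict):
    # Mirrors the new branch in init_model: chunked-prefill mode now checks
    # spec_algo and routes MTP requests to the MTP-aware backend.
    if kvargs.get("spec_algo", "NONE") == "MTP":
        return ChunkedPrefillWithMTPBackend()
    return ChunkedPrefillBackend()

print(pick_chunked_backend({"spec_algo": "MTP"}).__class__.__name__)  # ChunkedPrefillWithMTPBackend
print(pick_chunked_backend({}).__class__.__name__)                    # ChunkedPrefillBackend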
