
Commit 7ff2329

share mem_index between draft and main
1 parent 9c48a23 commit 7ff2329

File tree: 3 files changed (+47 lines, -110 lines)

lightllm/models/deepseek_mtp/model.py

Lines changed: 18 additions & 1 deletion
@@ -7,6 +7,7 @@
 import torch
 from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
 from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
+from lightllm.common.req_manager import ReqManager
 from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.common.basemodel.cuda_graph import CudaGraph
@@ -19,9 +20,25 @@ class Deepseek3MTPModel(Deepseek2TpPartModel):
     pre_layer_infer_class = Deepseek3MTPPreLayerInfer
 
     def __init__(self, kvargs):
-        self.main_model = kvargs["main_model"]
+        self.main_model = kvargs.pop("main_model")
+        self.req_manager = self.main_model.req_manager
         super().__init__(kvargs)
 
+    def _init_req_manager(self):
+        # draft model shares the same req_manager with the main model
+        if hasattr(self, "req_manager"):
+            print("SKIP INIT REQ!!!!!!!!")
+            return
+        create_max_seq_len = 0
+
+        if self.batch_max_tokens is not None:
+            create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens)
+        if self.max_seq_length is not None:
+            create_max_seq_len = max(create_max_seq_len, self.max_seq_length)
+
+        self.req_manager = ReqManager(self.max_req_num, create_max_seq_len, self.mem_manager)
+        return
+
     def _init_mem_manager(self):
         self.mem_manager = Deepseek3MTPMemoryManager(
             self.max_total_token_num,
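The intent of this hunk, as a minimal standalone sketch: the draft model adopts the main model's request manager, so its own `_init_req_manager` becomes a no-op. The `*Stub` classes below are illustrative stand-ins, not lightllm classes; only the hasattr-guarded skip mirrors the real change.

```python
class ReqManagerStub:
    """Illustrative stand-in for lightllm's ReqManager (per-request token index rows)."""

    def __init__(self, max_req_num: int):
        self.req_to_token_indexs = [[] for _ in range(max_req_num)]


class MainModelStub:
    def __init__(self, max_req_num: int):
        self.req_manager = ReqManagerStub(max_req_num)


class DraftModelStub:
    def __init__(self, kvargs: dict):
        main_model = kvargs.pop("main_model")
        # Adopt the main model's request manager before any further init runs,
        # mirroring Deepseek3MTPModel.__init__ in the diff above.
        self.req_manager = main_model.req_manager
        self._init_req_manager()

    def _init_req_manager(self):
        # Same guard as the new _init_req_manager: a manager is already attached
        # (shared from the main model), so skip creating a second one.
        if hasattr(self, "req_manager"):
            return
        self.req_manager = ReqManagerStub(max_req_num=1)


main = MainModelStub(max_req_num=8)
draft = DraftModelStub({"main_model": main})
assert draft.req_manager is main.req_manager  # one request table serves both models
```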

lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py

Lines changed: 18 additions & 68 deletions
@@ -24,11 +24,6 @@
 
 logger = init_logger(__name__)
 
-# TODO: optim
-def update_draft_token_mem_indexes(draft_token_memindex_map, run_reqs, mem_indexes):
-    for i, req in enumerate(run_reqs):
-        draft_token_memindex_map[req.req_idx] = mem_indexes[i]
-
 
 class ContinuesBatchWithMTPBackend(ModeBackend):
     def __init__(self) -> None:
@@ -81,6 +76,7 @@ def init_model(self, kvargs):
         self.mtp_draft_token_memindex_map = torch.full(
             (max_req_num,), fill_value=IS_NONE, dtype=torch.int32, device="cpu"
         )
+        self.accept_len = 0
 
     def prefill(self, reqs: List[Tuple]):
         self._init_reqs(reqs, init_req_obj=False)
@@ -107,10 +103,8 @@ def decode(self):
             next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-            # spec decode: MTP
-            draft_model_input = prepare_mtp_prefill_inputs(prefill_reqs, next_token_ids, self.draft_model.mem_manager)
-            # mtp embedding
-            draft_model_input.hidden_states = model_output.hidden_states
+            # spec prefill: MTP
+            draft_model_input = prepare_mtp_prefill_inputs(prefill_reqs, model_input, model_output, next_token_ids)
             draft_model_output = self.draft_model.forward(draft_model_input)
             draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id)
             draft_next_token_ids = draft_next_token_ids.detach().cpu().numpy()
@@ -121,9 +115,8 @@ def decode(self):
             )
 
         if decode_reqs:
-            model_input, run_reqs = prepare_draft_main_model_decode_inputs(decode_reqs, self.draft_token_id_map)
-            update_draft_token_mem_indexes(
-                self.main_draft_token_memindex_map, run_reqs[1::2], model_input.mem_indexes[1::2]
+            model_input, run_reqs, mem_indexes_cpu = prepare_draft_main_model_decode_inputs(
+                decode_reqs, self.draft_token_id_map
             )
             model_output = self.model.forward(model_input)
             assert model_output.logits.shape[0] % 2 == 0
@@ -148,7 +141,9 @@ def decode(self):
             next_token_ids1 = next_token_ids[1::2]
             next_token_logprobs1 = next_token_logprobs[1::2]
 
-            accepted_reqs, accepted_index = self.verify(next_token_ids0, run_reqs[::2])
+            accepted_reqs, accepted_index, need_free_mem_indexes = self.verify(
+                next_token_ids0, run_reqs[::2], mem_indexes_cpu[1::2]
+            )
             self._post_handle(
                 accepted_reqs,
                 next_token_ids1[accepted_index],
@@ -157,22 +152,23 @@ def decode(self):
                 do_filter_finished_reqs=False,
             )
             # spec decode: MTP
-            draft_model_input = copy.deepcopy(model_input)
+            draft_model_input = model_input
             draft_model_input.input_ids = torch.tensor(next_token_ids, dtype=torch.int64, device="cuda")
-            mtp_mem_indexes = self.draft_model.mem_manager.alloc(next_token_ids.shape[0]).cuda()
-            draft_model_input.mem_indexes = mtp_mem_indexes
             draft_model_input.hidden_states = model_output.hidden_states
-            update_draft_token_mem_indexes(self.mtp_draft_token_memindex_map, run_reqs[1::2], mtp_mem_indexes[1::2])
             draft_model_output = self.draft_model.forward(draft_model_input)
             draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id)
 
             accepted_req_idxs = [req.req_idx for req in accepted_reqs]
             self._save_draft_token_ids(draft_next_token_ids, run_reqs[::2], accepted_req_idxs)
+            if need_free_mem_indexes:
+                g_infer_state_lock.acquire()
+                g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes)
+                g_infer_state_lock.release()
 
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         return
 
-    def verify(self, next_token_ids0, run_reqs):
+    def verify(self, next_token_ids0, run_reqs, draft_mem_indexes):
         accepted_reqs = []
         accepted_index = []
         need_free_mem_indexes = []
@@ -181,66 +177,20 @@ def verify(self, next_token_ids0, run_reqs):
             if self.draft_token_id_map[req.req_idx] == next_token_ids0[i]:
                 accepted_reqs.append(req)
                 accepted_index.append(i)
-                self.main_draft_token_memindex_map[req.req_idx] = IS_NONE
+                self.accept_len += 1
+                device0_print(f"self.accept_len: {self.accept_len}")
             else:
-                need_free_mem_indexes.append(self.main_draft_token_memindex_map[req.req_idx])
-        if need_free_mem_indexes:
-            g_infer_state_lock.acquire()
-            g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes)
-            g_infer_state_lock.release()
-        return accepted_reqs, accepted_index
+                need_free_mem_indexes.append(draft_mem_indexes[i])
+        return accepted_reqs, accepted_index, need_free_mem_indexes
 
     def _save_draft_token_ids(self, draft_next_token_ids, run_reqs, accepted_reqs=None):
         assert accepted_reqs is None or draft_next_token_ids.shape[0] == 2 * len(run_reqs)
-        need_free_mem_indexes = []
        for i, req in enumerate(run_reqs):
            if accepted_reqs is None:
                self.draft_token_id_map[req.req_idx] = draft_next_token_ids[i]
            else:
                if req.req_idx in accepted_reqs:
                    self.draft_token_id_map[req.req_idx] = draft_next_token_ids[2 * i + 1]
-                    self.mtp_draft_token_memindex_map[req.req_idx] = IS_NONE
                else:
                    self.draft_token_id_map[req.req_idx] = draft_next_token_ids[2 * i]
-                    need_free_mem_indexes.append(self.mtp_draft_token_memindex_map[req.req_idx])
-
-        req = run_reqs[0]
-        if need_free_mem_indexes:
-            g_infer_state_lock.acquire()
-            self.draft_model.mem_manager.free(need_free_mem_indexes)
-            g_infer_state_lock.release()
-        return
-
-    def _overlap_req_init_and_filter(
-        self, uninit_reqs: List[InferReq], ok_finished_reqs: List[InferReq], clear_list=False
-    ):
-        if uninit_reqs or ok_finished_reqs:
-            with torch.cuda.stream(g_infer_context.get_overlap_stream()):
-                if ok_finished_reqs:
-                    g_infer_state_lock.acquire()
-                    self._free_mtp_model_memindex(ok_finished_reqs)
-                    g_infer_context.filter_reqs(ok_finished_reqs)
-                    g_infer_state_lock.release()
-
-                if uninit_reqs:
-                    g_infer_state_lock.acquire()
-                    self._post_init_reqs(uninit_reqs)
-                    g_infer_state_lock.release()
-
-            torch.cuda.current_stream().wait_stream(g_infer_context.get_overlap_stream())
-
-        if clear_list:
-            uninit_reqs.clear()
-            ok_finished_reqs.clear()
        return
-
-    def _free_mtp_model_memindex(self, ok_finished_reqs):
-        mtp_free_mem_indexes = []
-        for req in ok_finished_reqs:
-            mtp_free_mem_indexes.append(
-                self.draft_model.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]
-            )
-        free_memindexes = torch.cat(mtp_free_mem_indexes, dim=0)
-        g_infer_state_lock.acquire()
-        self.draft_model.req_manager.mem_manager.free(free_memindexes)
-        g_infer_state_lock.release()
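The core bookkeeping change in `verify`: the mem index of each draft token now arrives as an argument (`mem_indexes_cpu[1::2]`), and mismatched positions are returned so the caller can free them in one batch from the shared memory manager. A hedged standalone sketch of that accept/reject split follows; `verify_draft_tokens` and the plain Python lists are illustrative, not lightllm APIs.

```python
from typing import List, Tuple


def verify_draft_tokens(
    main_token_ids: List[int],      # next_token_ids0 in the diff
    draft_token_ids: List[int],     # self.draft_token_id_map[req.req_idx] per request
    draft_mem_indexes: List[int],   # mem_indexes_cpu[1::2] in the diff
) -> Tuple[List[int], List[int]]:
    accepted_index: List[int] = []
    need_free_mem_indexes: List[int] = []
    for i, (main_id, draft_id) in enumerate(zip(main_token_ids, draft_token_ids)):
        if main_id == draft_id:
            accepted_index.append(i)                            # draft token accepted
        else:
            need_free_mem_indexes.append(draft_mem_indexes[i])  # KV slot to reclaim
    return accepted_index, need_free_mem_indexes


accepted, to_free = verify_draft_tokens([5, 9, 3], [5, 7, 3], [100, 101, 102])
assert accepted == [0, 2] and to_free == [101]
```

In the backend itself, the returned indexes are released through `g_infer_context.req_manager.mem_manager.free(...)` under `g_infer_state_lock`, since the draft and main models now draw from the same pool rather than from separate managers.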

lightllm/server/router/model_infer/mode_backend/mtp_pre_process.py

Lines changed: 11 additions & 41 deletions
@@ -3,58 +3,27 @@
 from typing import List, Tuple
 from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock
-from lightllm.common.basemodel.batch_objs import ModelInput
+from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 
 IS_NONE = -1
 
 
-def prepare_mtp_prefill_inputs(req_objs: List[InferReq], tgt_input_ids, mem_manager):
-    nopad_total_token_num = 0
-    nopad_max_len_in_batch = 0
+def prepare_mtp_prefill_inputs(
+    req_objs: List[Tuple], model_input: ModelInput, model_output: ModelOutput, tgt_input_ids
+):
     input_ids = []
-    nopad_b_req_idx = []
-    nopad_b_seq_len = []
-    batch_multimodal_params = []
     for i, req in enumerate(req_objs):
-        batch_multimodal_params.append(req.multimodal_params)
-        nopad_b_req_idx.append(req.req_idx)
-
         input_token_ids = req.get_input_token_ids()
 
         input_token_ids = np.roll(input_token_ids, -1)
         input_token_ids[-1] = tgt_input_ids[i]
-
-        seq_len = len(input_token_ids)
-        input_token_len = seq_len
-
-        input_id = input_token_ids
-
-        nopad_b_seq_len.append(seq_len)
-        input_ids.append(input_id)
-        nopad_total_token_num += seq_len
-        nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_token_len)
+        input_ids.append(input_token_ids)
 
     input_ids = np.concatenate(input_ids, dtype=np.int64)
-
     input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda")
-    nopad_b_req_idx = torch.tensor(nopad_b_req_idx, dtype=torch.int32, device="cuda")
-    nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda")
-    b_ready_cache_len = torch.zeros(len(req_objs), dtype=torch.int32, device="cuda")
-
-    g_infer_state_lock.acquire()
-    mem_indexes = mem_manager.alloc(input_ids.shape[0]).cuda()
-    g_infer_state_lock.release()
-    model_input = ModelInput(
-        batch_size=len(req_objs),
-        total_token_num=nopad_total_token_num,
-        max_len_in_batch=nopad_max_len_in_batch,
-        input_ids=input_ids,
-        mem_indexes=mem_indexes,
-        b_req_idx=nopad_b_req_idx,
-        b_seq_len=nopad_b_seq_len,
-        b_ready_cache_len=b_ready_cache_len,
-        is_prefill=True,
-    )
+    model_input.input_ids = input_ids
+    # mtp embedding
+    model_input.hidden_states = model_output.hidden_states
     return model_input
 
 
@@ -95,7 +64,8 @@ def prepare_draft_main_model_decode_inputs(req_objs: List[Tuple], draft_token_id
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]).cuda()
+    mem_indexes_cpu = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0])
+    mem_indexes = mem_indexes_cpu.cuda()
     g_infer_state_lock.release()
     model_input = ModelInput(
         batch_size=len(run_reqs),
@@ -107,4 +77,4 @@ def prepare_draft_main_model_decode_inputs(req_objs: List[Tuple], draft_token_id
         b_seq_len=nopad_b_seq_len,
         is_prefill=False,
     )
-    return model_input, run_reqs
+    return model_input, run_reqs, mem_indexes_cpu
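With this change, `prepare_mtp_prefill_inputs` only rebuilds `input_ids` (each prompt shifted left by one, with the main model's freshly sampled token appended) and reuses the main model's `ModelInput` and hidden states directly, instead of allocating separate mem indexes for the draft model. A small sketch of just the token-shift step in plain NumPy; `shift_for_mtp` is an illustrative helper name, not a lightllm function.

```python
import numpy as np


def shift_for_mtp(input_token_ids: np.ndarray, sampled_token_id: int) -> np.ndarray:
    # Shift the prompt left by one position and place the main model's sampled
    # token at the end, so the draft model predicts one step ahead.
    shifted = np.roll(input_token_ids, -1)
    shifted[-1] = sampled_token_id
    return shifted


prompt = np.array([101, 7, 42, 9], dtype=np.int64)
print(shift_for_mtp(prompt, 55))  # -> [ 7 42  9 55]
```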
