
Commit 9c48a23

decode for mtp
1 parent 9fb95d1 commit 9c48a23

2 files changed: +8 additions, −16 deletions

lightllm/server/router/manager.py

Lines changed: 8 additions & 4 deletions
@@ -62,6 +62,8 @@ def __init__(self, args, router_port, detokenization_port, metric_port):
         # initialize radix_cache_client, used to read the prompt cache management info
         self.radix_cache_client = None

+        self.spec_step = args.spec_step
+
         # shared variable storing the machine load info produced by router-side scheduling analysis
         self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
         for dp_index in range(self.dp_size_in_node):
@@ -386,8 +388,9 @@ async def _prefill_batch(self, batch: Batch):
         self.overlap_event.set()
         await self.model_rpc_client.prefill(reqs)
         batch.filter_out_finished_req(self.shm_req_manager)
-        # send a None packet to trigger detokenization
-        self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)
+        # send spec_step + 1 None packets to trigger detokenization
+        for _ in range(self.spec_step + 1):
+            self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)

         logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
         self.metric_client.histogram_observe(
@@ -403,8 +406,9 @@ async def _decode_batch(self, batch: Batch):
         # when self.is_multinode_and_multidp is True, the incoming batch object may be None
         if batch is not None:
             batch.filter_out_finished_req(self.shm_req_manager)
-            # send a None packet to trigger detokenization
-            self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)
+            # send spec_step + 1 None packets to trigger detokenization
+            for _ in range(self.spec_step + 1):
+                self.send_to_detokenization.send_pyobj(None, protocol=pickle.HIGHEST_PROTOCOL)
         self.metric_client.histogram_observe(
             "lightllm_batch_inference_duration_bucket", time.time() - start_time, "decode"
         )
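
For context, a minimal self-contained sketch of the wake-up pattern this change adopts (not LightLLM code; DetokenizationStub and RouterStub are hypothetical stand-ins): with MTP speculative decoding, one step can emit up to spec_step + 1 tokens per request (the main-model token plus up to spec_step accepted draft tokens), so the router now sends one None packet per potentially emitted token instead of a single packet.

# minimal sketch, assuming a queue stands in for the zmq socket and shared req state
import queue

class DetokenizationStub:
    def __init__(self):
        self.pending_tokens = queue.Queue()  # stands in for per-request shared-memory output state
        self.wakeups = queue.Queue()         # stands in for the socket the router sends None packets on

    def handle_one_wakeup(self):
        # each None packet lets the detokenizer run one pass and pull at most one pending token
        self.wakeups.get()
        return None if self.pending_tokens.empty() else self.pending_tokens.get()

class RouterStub:
    def __init__(self, spec_step, detok):
        self.spec_step = spec_step
        self.detok = detok

    def decode_step(self, emitted_tokens):
        # an MTP decode step may have produced 1 .. spec_step + 1 tokens per request
        for tok in emitted_tokens:
            self.detok.pending_tokens.put(tok)
        # mirrors the diff above: one None packet per potentially emitted token,
        # so no token sits undetokenized until the next step
        for _ in range(self.spec_step + 1):
            self.detok.wakeups.put(None)

if __name__ == "__main__":
    detok = DetokenizationStub()
    router = RouterStub(spec_step=1, detok=detok)
    router.decode_step(emitted_tokens=[101, 102])          # main token + one accepted draft token
    print([detok.handle_one_wakeup() for _ in range(2)])   # -> [101, 102]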

lightllm/server/router/model_infer/mode_backend/continues_batch/impl_mtp.py

Lines changed: 0 additions & 12 deletions
@@ -33,8 +33,6 @@ def update_draft_token_mem_indexes(draft_token_memindex_map, run_reqs, mem_index
 class ContinuesBatchWithMTPBackend(ModeBackend):
     def __init__(self) -> None:
         super().__init__()
-        self.accepted_cnt = 0
-        self.all_cnt = 0

     # support dual models
     def init_model(self, kvargs):
@@ -83,8 +81,6 @@ def init_model(self, kvargs):
         self.mtp_draft_token_memindex_map = torch.full(
             (max_req_num,), fill_value=IS_NONE, dtype=torch.int32, device="cpu"
         )
-        self.draft_accept_count = torch.zeros((max_req_num,), dtype=torch.int32, device="cpu")
-        self.main_step = 0

     def prefill(self, reqs: List[Tuple]):
         self._init_reqs(reqs, init_req_obj=False)
@@ -103,8 +99,6 @@ def decode(self):
                 prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
             )
             model_output = self.model.forward(model_input)
-            self.main_step += 1
-            device0_print(f"main_step: {self.main_step}")

             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
@@ -134,9 +128,6 @@ def decode(self):
             model_output = self.model.forward(model_input)
             assert model_output.logits.shape[0] % 2 == 0

-            self.main_step += 1
-            device0_print(f"main_step: {self.main_step}")
-
             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
             )
@@ -165,7 +156,6 @@ def decode(self):
                 is_chuncked_mode=False,
                 do_filter_finished_reqs=False,
             )
-
             # spec decode: MTP
             draft_model_input = copy.deepcopy(model_input)
             draft_model_input.input_ids = torch.tensor(next_token_ids, dtype=torch.int64, device="cuda")
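
As a reading aid, a hedged sketch of the draft-input chaining visible in the context above (ModelInput and build_draft_input are stand-in names, not LightLLM's API; requires torch): the draft (MTP) model reuses the main model's batch metadata but takes the freshly sampled tokens as its input_ids, so it can propose one speculative token per request for the next step.

# hedged sketch with stand-in types; the real code works on LightLLM's batch input object
import copy
from dataclasses import dataclass
from typing import List

import torch

@dataclass
class ModelInput:                  # stand-in for the batch input object being deep-copied above
    input_ids: torch.Tensor
    b_seq_len: torch.Tensor        # illustrative per-request metadata field

def build_draft_input(model_input: ModelInput, next_token_ids: List[int], device: str = "cpu") -> ModelInput:
    # reuse the main model's batch metadata unchanged ...
    draft_model_input = copy.deepcopy(model_input)
    # ... but feed it the tokens the main model just sampled, so the MTP draft
    # model proposes one speculative token per request for the next decode step
    draft_model_input.input_ids = torch.tensor(next_token_ids, dtype=torch.int64, device=device)
    return draft_model_input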
@@ -191,8 +181,6 @@ def verify(self, next_token_ids0, run_reqs):
             if self.draft_token_id_map[req.req_idx] == next_token_ids0[i]:
                 accepted_reqs.append(req)
                 accepted_index.append(i)
-                self.draft_accept_count[req.req_idx] += 1
-                device0_print(f"draft_accept_count: {self.draft_accept_count[req.req_idx]}")
                 self.main_draft_token_memindex_map[req.req_idx] = IS_NONE
             else:
                 need_free_mem_indexes.append(self.main_draft_token_memindex_map[req.req_idx])
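
For reference, a simplified sketch of the accept/reject rule that verify() applies above (stand-in types; the real code tracks request objects and a torch mem-index map, and also records accepted_index): a draft token proposed on the previous step is accepted only if the main model now samples exactly that token; on a mismatch the KV slot reserved for the draft token is returned to the allocator.

# hedged sketch under the assumptions stated in the note above
from dataclasses import dataclass
from typing import Dict, List, Tuple

IS_NONE = -1  # sentinel for "no draft KV slot tracked", mirroring impl_mtp.py

@dataclass
class Req:
    req_idx: int

def verify_draft_tokens(
    next_token_ids0: List[int],                 # tokens the main model actually sampled this step
    run_reqs: List[Req],
    draft_token_id_map: Dict[int, int],         # draft token proposed for each request last step
    draft_token_memindex_map: Dict[int, int],   # KV slot reserved for that draft token
) -> Tuple[List[Req], List[int]]:
    accepted_reqs: List[Req] = []
    need_free_mem_indexes: List[int] = []
    for i, req in enumerate(run_reqs):
        if draft_token_id_map[req.req_idx] == next_token_ids0[i]:
            # draft guess matched: accept the token; its KV slot is no longer tracked for freeing
            accepted_reqs.append(req)
            draft_token_memindex_map[req.req_idx] = IS_NONE
        else:
            # mismatch: the speculative KV slot goes back to the allocator
            need_free_mem_indexes.append(draft_token_memindex_map[req.req_idx])
    return accepted_reqs, need_free_mem_indexes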
