Commit 50591d3

fix

1 parent e186bed commit 50591d3

6 files changed: +41 -125 lines changed

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 8 additions & 3 deletions
@@ -53,7 +53,8 @@ def get_overlap_stream(self) -> torch.cuda.Stream:
             self.overlap_stream = torch.cuda.Stream()
         return self.overlap_stream

-    def add_reqs(self, requests: List[Tuple[int, int, Any, int]]):
+    def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]:
+        req_objs = []
         request_ids = []
         for r in requests:
             r_id, r_index, multimodal_params, _ = r
@@ -64,12 +65,14 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]]):
                 shm_index=r_index,
                 multimodal_params=multimodal_params,
                 vocab_size=self.vocab_size,
+                init_prefix_cache=init_prefix_cache,
             )
             self.requests_mapping[r_id] = r_obj
             request_ids.append(r_id)
+            req_objs.append(r_obj)

         self.infer_req_ids.extend(request_ids)
-        return
+        return req_objs

     def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finished: bool):
         if self.radix_cache is None:
@@ -261,6 +264,7 @@ def __init__(
         shm_index: int,
         multimodal_params=None,
         vocab_size: int = -1,
+        init_prefix_cache: bool = True,
     ):
         self.req_id = req_id
         self.req_idx = req_idx
@@ -285,7 +289,8 @@ def __init__(
         self.mtp_gen_token_ids: List[int] = []

         self._init_all_state()
-        self._match_radix_cache()
+        if init_prefix_cache:
+            self._match_radix_cache()
         return

     def _init_all_state(self):
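
Note on the new `add_reqs` signature: the method now accepts an `init_prefix_cache` flag and returns the created `InferReq` objects, so a caller can defer radix-cache matching. A minimal sketch of how a caller might use it (the request tuples below are illustrative assumptions; real callers such as `DecodeNode._init_reqs` also hold `g_infer_state_lock` and run capacity checks first):

    # Sketch only: assumes a fully initialized inference context (g_infer_context)
    # inside a running lightllm router worker.
    from lightllm.server.router.model_infer.infer_batch import g_infer_context

    # Each request tuple is (req_id, shm_index, multimodal_params, dp_rank_in_node);
    # the concrete values here are made up for illustration.
    reqs = [(1001, 0, None, 0), (1002, 1, None, 0)]

    # Create the InferReq objects without matching the prefix (radix) cache yet.
    req_objs = g_infer_context.add_reqs(reqs, init_prefix_cache=False)

    # Later, once the caller has decided the requests fit, match the radix cache.
    for req_obj in req_objs:
        req_obj._match_radix_cache()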

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 0 additions & 1 deletion
@@ -252,7 +252,6 @@ def _try_read_new_reqs(self):
                 req.infer_aborted = True
         else:
             self._init_reqs(reqs=cmds)
-            self.chunked_prefill_state.need_prefill_count += 1
         return

    # Some reusable general-purpose helper functions

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py

Lines changed: 19 additions & 21 deletions
@@ -3,10 +3,9 @@
 import torch.multiprocessing as mp
 import torch.distributed as dist
 import threading
-from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
-from lightllm.server.router.model_infer.mode_backend.continues_batch.impl import ContinuesBatchBackend
+from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend
 from typing import List, Tuple
-from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
+from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, g_infer_state_lock
 from lightllm.server.core.objs import FinishStatus
 from lightllm.utils.log_utils import init_logger
 from rpyc.utils.server import ThreadedServer
@@ -19,7 +18,7 @@
 logger = init_logger(__name__)


-class ContinuesBatchBackendForDecodeNode(ModeBackend):
+class DecodeNode(ChunkedPrefillBackend):
     def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
         super().__init__()
         self.info_queue: mp.Queue = info_queue
@@ -48,23 +47,23 @@ def init_custom(self):

         return

-    def decode(self):
-        uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
-            g_infer_context.infer_req_ids,
-            no_decode=False,
-        )
-        # In PD-disaggregated mode, the decode node can never hold requests that still need prefill
-        assert len(prefill_reqs) == 0
+    def _init_reqs(self, reqs: List[Tuple]):
+        """
+        Replace the request-initialization step with the special init flow unique to the Decode node
+        """
+        if self.dp_size_in_node != 1:
+            dp_rank_in_node = self.dp_rank_in_node
+            reqs = [req for req in reqs if req[3] == dp_rank_in_node]

-        self._filter_reqs(aborted_reqs)
+        g_infer_state_lock.acquire()

-        if decode_reqs:
-            ContinuesBatchBackend.normal_decode(
-                self, decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs
-            )
+        uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=False)
+        # Match the radix cache and update resource accounting.
+        self._post_init_reqs(uninit_reqs=uninit_reqs)

-        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
-        return
+        g_infer_state_lock.release()
+        req_ids = [e[0] for e in reqs]
+        return req_ids

     def _post_init_reqs(self, uninit_reqs: List[InferReq]):
         """
@@ -85,12 +84,11 @@ def _post_init_reqs(self, uninit_reqs: List[InferReq]):
                 req_all_len = len(task.input_tokens) + task.decode_node.max_new_tokens
                 remove_count += req_all_len
                 estimated_peak_token_count += req_all_len
-                req_obj.init_all()
+                req_obj._match_radix_cache()
             else:
                 # For invalid requests, simply simulate them as finished
-                req_obj.init_all()
-                req_obj.set_next_gen_token_id(0, 0.0)
                 req_obj.cur_output_len += 1
+                req_obj.set_next_gen_token_id(0, 0.0, 1)
                 req_obj.finish_status.set_status(FinishStatus.FINISHED_STOP)

         if self.is_master_in_dp:
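
The `_init_reqs` override above splits request setup into two phases under `g_infer_state_lock`: first create the `InferReq` objects with `init_prefix_cache=False`, then let `_post_init_reqs` match the radix cache and update resource accounting. A simplified sketch of that pattern (the `try/finally` guard is an added assumption for clarity; the actual method acquires and releases the lock directly and also does per-request capacity checks inside `_post_init_reqs`):

    # Simplified sketch of the two-phase init used by DecodeNode._init_reqs.
    from typing import List, Tuple
    from lightllm.server.router.model_infer.infer_batch import g_infer_context, g_infer_state_lock

    def init_decode_reqs(backend, reqs: List[Tuple]):
        # In DP mode, keep only the requests routed to this node's dp rank (req[3]).
        if backend.dp_size_in_node != 1:
            reqs = [req for req in reqs if req[3] == backend.dp_rank_in_node]

        g_infer_state_lock.acquire()
        try:
            # Phase 1: create InferReq objects, skipping radix-cache matching.
            uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=False)
            # Phase 2: match the radix cache and update resource accounting.
            backend._post_init_reqs(uninit_reqs=uninit_reqs)
        finally:
            g_infer_state_lock.release()
        return [e[0] for e in reqs]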
Lines changed: 14 additions & 23 deletions
@@ -1,37 +1,28 @@
-import torch
 import torch.multiprocessing as mp
-import torch.distributed as dist
-from typing import List, Tuple
 from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
 from lightllm.utils.log_utils import init_logger
-from lightllm.server.router.model_infer.mode_backend.pre import padded_prepare_prefill_inputs
-from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
-from .decode_impl import ContinuesBatchBackendForDecodeNode
+from typing import List, Tuple
 from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend
+from .decode_impl import DecodeNode

 logger = init_logger(__name__)


-class DPForDecodeNode(ContinuesBatchBackendForDecodeNode):
+class DPForDecodeNode(DPChunkedPrefillBackend):
     def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
-        super().__init__(info_queue, mem_queue)
-        self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap
+        super().__init__()
+        self.info_queue: mp.Queue = info_queue
+        self.mem_queue: mp.Queue = mem_queue
         return

-    def decode(self):
-        uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs(
-            g_infer_context.infer_req_ids
-        )
-        assert len(prefill_reqs) == 0
-
-        self._filter_reqs(aborted_reqs)
+    def init_custom(self):
+        DecodeNode.init_custom(self)
+        return

-        max_decode_num = self._dp_all_reduce_decode_req_num(decode_reqs=decode_reqs)
-        if max_decode_num != 0:
-            if not self.enable_decode_microbatch_overlap:
-                DPChunkedPrefillBackend.normal_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
-            else:
-                DPChunkedPrefillBackend.overlap_decode(self, decode_reqs, max_decode_num, uninit_reqs, ok_finished_reqs)
+    def _init_reqs(self, reqs: List[Tuple]):
+        DecodeNode._init_reqs(self, reqs=reqs)
+        return

-        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
+    def _post_init_reqs(self, uninit_reqs: List[InferReq]):
+        DecodeNode._post_init_reqs(self, uninit_reqs=uninit_reqs)
         return
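
`DPForDecodeNode` now inherits its decode loop from `DPChunkedPrefillBackend` and borrows `DecodeNode`'s request-initialization hooks by calling them explicitly with `self` rather than relying on the MRO. A small runnable toy illustrating that delegation pattern (the class names below are placeholders, not from the codebase):

    # Toy illustration of the explicit-delegation pattern used by DPForDecodeNode.
    class BaseBackend:
        def _init_reqs(self, reqs):
            return []

    class DecodeHooks(BaseBackend):
        def _init_reqs(self, reqs):
            # Decode-specific initialization.
            return [r[0] for r in reqs]

    class DPBackend(BaseBackend):
        def _init_reqs(self, reqs):
            # Reuse DecodeHooks' implementation without inheriting from it,
            # keeping DPBackend's own base class for everything else.
            return DecodeHooks._init_reqs(self, reqs)

    print(DPBackend()._init_reqs([(1,), (2,)]))  # -> [1, 2]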

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl_mtp_for_dp.py

Lines changed: 0 additions & 41 deletions
This file was deleted.
