
Commit d7d39a8

Author: wangzaijun
Commit message: fix
1 parent 6450d72 commit d7d39a8

22 files changed: +2250 -0 lines changed

lightllm/server/pd_io_struct.py

Lines changed: 38 additions & 0 deletions
@@ -205,3 +205,41 @@ def get_cost_time(self):
class KVMoveTaskGroup:
    tasks: List[KVMoveTask]
    connect_id: str


####### Objects below are specific to NIXL mode ########


@dataclass
class NIXLChunckedTransTask:
    request_id: int
    dp_index: int
    trans_device_id: int  # id of the transfer device used on the current node, i.e. which GPU card it corresponds to
    start_kv_index: int
    end_kv_index: int
    mem_indexes: List[int]
    is_last_chunk: bool

    def __post_init__(self):
        if self.start_kv_index < 0 or self.end_kv_index <= self.start_kv_index:
            error_info = "start_kv_index must be >= 0 and end_kv_index must be > start_kv_index"
            logger.error(error_info)
            raise ValueError(error_info)
        if len(self.mem_indexes) == 0:
            error_info = "mem_indexes must have len > 0"
            logger.error(error_info)
            raise ValueError(error_info)
        assert len(self.mem_indexes) == (self.end_kv_index - self.start_kv_index)
        return


@dataclass
class PrefillTransTaskRet:
    request_id: int
    is_error: bool
    error_info: str = None


@dataclass
class NIXLStopTransTask:
    request_id: int
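
Below is a brief usage sketch (not part of the commit) showing how NIXLChunckedTransTask's __post_init__ validation behaves; all field values are invented for illustration.

from lightllm.server.pd_io_struct import NIXLChunckedTransTask

# A valid chunk: len(mem_indexes) must equal end_kv_index - start_kv_index.
task = NIXLChunckedTransTask(
    request_id=1,
    dp_index=0,
    trans_device_id=0,
    start_kv_index=0,
    end_kv_index=4,
    mem_indexes=[10, 11, 12, 13],
    is_last_chunk=True,
)

# An empty mem_indexes (or end_kv_index <= start_kv_index) raises ValueError.
try:
    NIXLChunckedTransTask(
        request_id=2, dp_index=0, trans_device_id=0,
        start_kv_index=0, end_kv_index=1, mem_indexes=[], is_last_chunk=False,
    )
except ValueError as e:
    print(f"rejected: {e}")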

lightllm/server/router/model_infer/mode_backend/pd_nixl/pd_mode/__init__.py

Whitespace-only changes.
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .decode_kv_move_manager import start_decode_kv_move_manager_process
from .decode_trans_process import start_decode_trans_process
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
import os
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
import threading
from lightllm.server.router.model_infer.mode_backend.chunked_prefill.impl import ChunkedPrefillBackend
from typing import List, Tuple
from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq, g_infer_state_lock
from lightllm.server.core.objs import FinishStatus
from lightllm.utils.log_utils import init_logger
from rpyc.utils.server import ThreadedServer
from lightllm.common.basemodel.infer_lock import g_router_lock
from .decode_task_cache import g_success_kv_move_task_cache, KVMoveTask
from lightllm.utils.device_utils import kv_trans_use_p2p
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.dist_utils import create_new_group_for_current_dp

logger = init_logger(__name__)


class DecodeNode(ChunkedPrefillBackend):
    def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
        super().__init__()
        self.info_queue: mp.Queue = info_queue
        self.mem_queue: mp.Queue = mem_queue
        self.classed_req_strict_prefill = False

    def init_custom(self):

        self.lock_nccl_group = create_new_group_for_current_dp("gloo")
        logger.info(f"lock_nccl_group ranks {dist.get_rank(self.lock_nccl_group)}")

        from .decode_infer_rpyc import PDDecodeInferRpcServer

        socket_path = f"/tmp/{get_unique_server_name()}_decode_node_infer_rpyc_{self.pd_rpyc_ports[self.rank_in_node]}"
        if os.path.exists(socket_path):
            os.remove(socket_path)

        t = ThreadedServer(
            PDDecodeInferRpcServer(self), socket_path=socket_path, protocol_config={"allow_pickle": True}
        )
        threading.Thread(target=lambda: t.start(), daemon=True).start()

        if kv_trans_use_p2p():
            from ..p2p_fix import reduce_tensor

            mp.reductions.reduce_tensor.__code__ = reduce_tensor.__code__

        return

    def _init_reqs(self, reqs: List[Tuple]):
        """
        Replaces the normal request initialization with the special init flow unique to the Decode node.
        """
        if self.dp_size_in_node != 1:
            dp_rank_in_node = self.dp_rank_in_node
            reqs = [req for req in reqs if req[3] == dp_rank_in_node]

        g_infer_state_lock.acquire()

        uninit_reqs = g_infer_context.add_reqs(reqs, init_prefix_cache=False)
        # Match against the radix cache and update the related resource bookkeeping.
        self._post_init_reqs(uninit_reqs=uninit_reqs)

        g_infer_state_lock.release()
        req_ids = [e[0] for e in reqs]
        return req_ids

    def _post_init_reqs(self, uninit_reqs: List[InferReq]):
        """
        Check each request's kv len and immediately finish requests that may be problematic.
        """
        if len(uninit_reqs) == 0:
            return

        remove_count = 0
        estimated_peak_token_count = 0
        for req_obj in uninit_reqs:
            req_obj: InferReq = req_obj  # for easy typing
            request_id = req_obj.req_id
            if request_id in g_success_kv_move_task_cache:
                task, share_node, _ = g_success_kv_move_task_cache.pop(request_id)
                task: KVMoveTask = task  # for easy typing
                self.radix_cache.dec_node_ref_counter(share_node)
                req_all_len = len(task.input_tokens) + task.decode_node.max_new_tokens
                remove_count += req_all_len
                estimated_peak_token_count += req_all_len
                req_obj._match_radix_cache()
            else:
                # For invalid requests, directly simulate finishing them.
                req_obj.cur_output_len += 1
                req_obj.set_next_gen_token_id(0, 0.0, 1)
                req_obj.finish_status.set_status(FinishStatus.FINISHED_STOP)

                if self.is_master_in_dp:
                    req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len
                    req_obj.shm_req.shm_cur_output_len = req_obj.cur_output_len
                    req_obj.shm_req.finish_token_index = req_obj.get_cur_total_len() - 1
                    req_obj.shm_req.finish_status.set_status(FinishStatus.FINISHED_STOP)
                    req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len

                req_id = req_obj.shm_req.request_id
                logger.error(f"req_id: {req_id} forced to finish, it is not in g_success_kv_move_task_cache")

        if self.is_master_in_dp:
            with g_router_lock.obj:
                self.shared_token_load.add_frozened_token_count(-remove_count, self.dp_rank_in_node)
                self.shared_token_load.add_estimated_peak_token_count(estimated_peak_token_count, self.dp_rank_in_node)
        return
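
A minimal sketch (not from the commit) of the token accounting that _post_init_reqs performs for one successfully moved request; the counter values are hypothetical and only illustrate the arithmetic.

input_tokens_len = 512              # stands in for len(task.input_tokens)
max_new_tokens = 128                # stands in for task.decode_node.max_new_tokens
req_all_len = input_tokens_len + max_new_tokens  # 640

frozen_token_count = 2048           # hypothetical current frozen-token counter
estimated_peak_token_count = 4096   # hypothetical current estimated-peak counter

# The request's tokens stop being "frozen" (reserved while the KV was moved)
# and are instead counted toward the estimated peak usage of the running request.
frozen_token_count -= req_all_len             # 1408
estimated_peak_token_count += req_all_len     # 4736
print(frozen_token_count, estimated_peak_token_count)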
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
import torch.multiprocessing as mp
from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq
from lightllm.utils.log_utils import init_logger
from typing import List, Tuple
from lightllm.server.router.model_infer.mode_backend.dp_backend.impl import DPChunkedPrefillBackend
from .decode_impl import DecodeNode

logger = init_logger(__name__)


class DPForDecodeNode(DPChunkedPrefillBackend):
    def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
        super().__init__()
        self.info_queue: mp.Queue = info_queue
        self.mem_queue: mp.Queue = mem_queue
        self.classed_req_strict_prefill = False
        return

    def init_custom(self):
        DecodeNode.init_custom(self)
        return

    def _init_reqs(self, reqs: List[Tuple]):
        # Propagate the req_ids computed by DecodeNode._init_reqs.
        return DecodeNode._init_reqs(self, reqs=reqs)

    def _post_init_reqs(self, uninit_reqs: List[InferReq]):
        DecodeNode._post_init_reqs(self, uninit_reqs=uninit_reqs)
        return
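
Note that DPForDecodeNode inherits from DPChunkedPrefillBackend, not from DecodeNode, so it reuses DecodeNode's methods by calling them as plain functions with itself passed as self. A self-contained sketch of that delegation pattern (class names here are illustrative, not from the commit):

class Base:
    def hello(self):
        return f"hello from {type(self).__name__}"


class Other:
    def hello(self):
        # Reuse Base.hello with an Other instance as `self`, without inheriting from Base.
        return Base.hello(self)


print(Other().hello())  # prints "hello from Other"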
Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
import torch
import torch.distributed as dist
import rpyc
import time
from typing import Dict, List, Tuple, Optional, Union
from rpyc.utils.classic import obtain
from .decode_impl import DecodeNode
from lightllm.common.basemodel.infer_lock import acquire_lock_until_ready, release_acquired_lock, g_router_lock
from .decode_task_cache import g_kv_move_task_cache, g_success_kv_move_task_cache
from lightllm.server.pd_io_struct import KVMoveTask
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class PDDecodeInferRpcServer(rpyc.Service):
    def __init__(self, backend: DecodeNode) -> None:
        super().__init__()
        self.backend = backend
        self.device_id = self.backend.current_device_id
        self.dp_rank_in_node = self.backend.dp_rank_in_node
        self.is_master_in_dp = self.backend.is_master_in_dp
        return

    def on_connect(self, conn):
        torch.cuda.set_device(f"cuda:{self.device_id}")
        return

    def judge_token_is_ok(self, key_len, max_new_token):
        # In multi-dp, single-GPU-per-dp mode, each dp handles its own load; no synchronization is needed.
        if self.backend.dp_world_size == 1:
            with g_router_lock.obj:
                shared_token_load = self.backend.shared_token_load
                peak_num = shared_token_load.get_estimated_peak_token_count(self.dp_rank_in_node)
                peak_num += shared_token_load.get_frozened_token_count(self.dp_rank_in_node)
                peak_num += key_len + max_new_token

                if peak_num < self.backend.get_max_total_token_num():
                    object_list = [True]
                    shared_token_load.add_frozened_token_count(key_len + max_new_token, self.dp_rank_in_node)
                else:
                    object_list = [False]
            return object_list[0]

        # In normal single-dp mode, only the master rank does the bookkeeping and broadcasts the result to the other ranks.
        if self.is_master_in_dp:
            with g_router_lock.obj:
                shared_token_load = self.backend.shared_token_load
                peak_num = shared_token_load.get_estimated_peak_token_count(self.dp_rank_in_node)
                peak_num += shared_token_load.get_frozened_token_count(self.dp_rank_in_node)
                peak_num += key_len + max_new_token

                if peak_num < self.backend.get_max_total_token_num():
                    object_list = [True]
                    shared_token_load.add_frozened_token_count(key_len + max_new_token, self.dp_rank_in_node)
                else:
                    object_list = [False]
            dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group)
        else:
            object_list = [None]
            dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group)
        return object_list[0]

    def recover_frozen_token(self, key_len, max_new_token):
        if self.is_master_in_dp:
            with g_router_lock.obj:
                shared_token_load = self.backend.shared_token_load
                shared_token_load.add_frozened_token_count(-(key_len + max_new_token), self.dp_rank_in_node)
        return

    def _alloc_to_frozen_some_tokens(self, move_task: KVMoveTask):
        is_ok = self.judge_token_is_ok(len(move_task.input_tokens), move_task.decode_node.max_new_tokens)
        if not is_ok:
            if self.is_master_in_dp:
                logger.info(f"req_id: {move_task.to_decode_log_info()} alloc token failed")
                shared_token_load = self.backend.shared_token_load
                dp_rank = self.dp_rank_in_node
                frozen_token_num = shared_token_load.get_frozened_token_count(dp_rank)
                estimated_peak_token_num = shared_token_load.get_estimated_peak_token_count(dp_rank)
                logger.debug(
                    f"radix refed token num {self.backend.radix_cache.get_refed_tokens_num()}\n"
                    f"radix hold token num {self.backend.radix_cache.get_tree_total_tokens_num()}\n"
                    f"mem manager can alloc token num {self.backend.model.mem_manager.can_use_mem_size}\n"
                    f"mem manager total size {self.backend.model.mem_manager.size}\n"
                    f"frozened token num {frozen_token_num}\n"
                    f"estimated peak token num {estimated_peak_token_num}\n"
                )
            return None

        key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu")
        tree_node, kv_len, fused_token_indexes = self.backend.radix_cache.match_prefix(key, update_refs=True)
        # If nothing matched, the matched length is 0; convert fused_token_indexes to a plain list.
        fused_token_indexes = [] if fused_token_indexes is None else fused_token_indexes.tolist()
        need_len = len(move_task.input_tokens) - kv_len
        if need_len == 0:
            alloc_token_indexes = []
        else:
            self.backend.radix_cache.free_radix_cache_to_get_enough_token(need_len)
            alloc_token_indexes = self.backend.model.mem_manager.alloc(need_len)
            if alloc_token_indexes is not None:
                alloc_token_indexes = alloc_token_indexes.tolist()

        if alloc_token_indexes is None:
            self.backend.radix_cache.dec_node_ref_counter(tree_node)
            self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens)
            return None

        move_task.decode_token_indexes = alloc_token_indexes
        move_task.move_kv_len = need_len

        g_kv_move_task_cache[move_task.group_request_id] = (move_task, tree_node, fused_token_indexes)
        return move_task.decode_token_indexes

    # Returning None means the service is too busy to schedule any new requests.
    def exposed_alloc_to_frozen_some_tokens(self, move_tasks: List[KVMoveTask]) -> List[Optional[List[int]]]:
        move_tasks = obtain(move_tasks)
        acquire_lock_until_ready(self.backend.lock_nccl_group)
        try:
            ans_list = []
            for move_task in move_tasks:
                ans_list.append(self._alloc_to_frozen_some_tokens(move_task))
            return ans_list
        except BaseException as e:
            logger.exception(str(e))
            return None
        finally:
            release_acquired_lock()

    def _put_kv_received_to_radix_cache(self, group_req_id: int):
        move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id)
        radix_cache = self.backend.radix_cache
        key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu")
        value = torch.tensor(fused_token_indexes + move_task.decode_token_indexes, dtype=torch.int64, device="cpu")
        prefix_len = radix_cache.insert(key, value)
        assert len(fused_token_indexes) <= prefix_len
        self.backend.model.mem_manager.free(value[len(fused_token_indexes) : prefix_len])
        self.backend.radix_cache.dec_node_ref_counter(tree_node)

        # Match the key again to lock it in the radix cache so it cannot be evicted in extreme cases;
        # the decode side later corrects this by decrementing the ref counter twice.
        tree_node, kv_len, _ = self.backend.radix_cache.match_prefix(key, update_refs=True)
        assert len(key) == kv_len
        g_success_kv_move_task_cache[group_req_id] = (move_task, tree_node, time.time())
        return

    def exposed_put_kv_received_to_radix_cache(self, group_req_ids: List[int]):
        group_req_ids = obtain(group_req_ids)
        acquire_lock_until_ready(self.backend.lock_nccl_group)
        for group_req_id in group_req_ids:
            self._put_kv_received_to_radix_cache(group_req_id)
        release_acquired_lock()
        return

    def _fail_to_realese_forzen_tokens(self, group_req_id: int):
        move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id)
        value = torch.tensor(move_task.decode_token_indexes, dtype=torch.int64, device="cpu")
        self.backend.model.mem_manager.free(value)
        self.backend.radix_cache.dec_node_ref_counter(tree_node)
        self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens)
        return

    def exposed_fail_to_realese_forzen_tokens(self, group_req_ids: List[int]):
        group_req_ids = obtain(group_req_ids)
        acquire_lock_until_ready(self.backend.lock_nccl_group)
        for group_req_id in group_req_ids:
            self._fail_to_realese_forzen_tokens(group_req_id)
        release_acquired_lock()
        return

    def exposed_put_mem_manager_to_mem_queue(self):
        self.backend.mem_queue.put(self.backend.model.mem_manager)
        logger.info("put mem manager to mem_queue ok")
        return

    def exposed_unfrozen_time_out_reqs_tokens(self):
        acquire_lock_until_ready(self.backend.lock_nccl_group)
        if self.backend.dp_world_size == 1:
            need_release_reqs = self._get_time_out_reqs()
            logger.info(f"kv time out reqs: {need_release_reqs}")
            remove_tokens = self._remove_time_out_reqs(need_release_reqs)
            if remove_tokens != 0:
                with g_router_lock.obj:
                    self.backend.shared_token_load.add_frozened_token_count(-remove_tokens, self.dp_rank_in_node)
        else:
            if self.is_master_in_dp:
                need_release_reqs = self._get_time_out_reqs()
                logger.info(f"kv time out reqs: {need_release_reqs}")
                dist.broadcast_object_list([need_release_reqs], src=0, group=self.backend.lock_nccl_group)
            else:
                receive_objs = [None]
                dist.broadcast_object_list(receive_objs, src=0, group=self.backend.lock_nccl_group)
                need_release_reqs = receive_objs[0]
            remove_tokens = self._remove_time_out_reqs(need_release_reqs)
            if self.is_master_in_dp and remove_tokens != 0:
                with g_router_lock.obj:
                    self.backend.shared_token_load.add_frozened_token_count(-remove_tokens, self.dp_rank_in_node)

        release_acquired_lock()
        return

    def _get_time_out_reqs(self):
        need_release_reqs = []
        for req_id, (_, _, time_mark) in g_success_kv_move_task_cache.items():
            # If a request has not been scheduled within 6s, its lock is proactively removed and the tokens it froze are released.
            if time.time() - time_mark > 6:
                need_release_reqs.append(req_id)
        return need_release_reqs

    def _remove_time_out_reqs(self, need_release_reqs: List[int]) -> int:
        remove_tokens = 0
        for req_id in need_release_reqs:
            task, tree_node, _ = g_success_kv_move_task_cache.pop(req_id)
            self.backend.radix_cache.dec_node_ref_counter(tree_node)
            remove_tokens += len(task.input_tokens) + task.decode_node.max_new_tokens
        return remove_tokens
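
A minimal sketch (not from the commit) of the admission rule implemented in judge_token_is_ok: a transfer is accepted only if the estimated peak plus the currently frozen tokens plus the request's own demand (key_len + max_new_token) stays below the node's max total token count, and on acceptance the request's demand is added to the frozen counter. The numbers below are hypothetical.

def judge_token_is_ok_sketch(estimated_peak, frozen, key_len, max_new_token, max_total_token_num):
    peak_num = estimated_peak + frozen + key_len + max_new_token
    if peak_num < max_total_token_num:
        frozen += key_len + max_new_token  # freeze this request's tokens
        return True, frozen
    return False, frozen


ok, frozen = judge_token_is_ok_sketch(
    estimated_peak=4096, frozen=1024, key_len=512, max_new_token=128, max_total_token_num=8192
)
print(ok, frozen)  # True 1664

The exposed_* methods follow rpyc's naming convention, so a client connected to the unix socket created in DecodeNode.init_custom can invoke them as conn.root.alloc_to_frozen_some_tokens(...), conn.root.put_kv_received_to_radix_cache(...), and so on.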
