Skip to content

Commit 6770350

Browse files
author
wangzaijun
committed
fix
1 parent 65a840e commit 6770350

File tree

6 files changed

+131
-83
lines changed

6 files changed

+131
-83
lines changed

lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_kv_move_manager.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
logger = init_logger(__name__)
1515

16+
1617
def start_decode_kv_move_manager_process(args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
1718
event = mp.Event()
1819
proc = mp.Process(target=_init_env, args=(args, info_queue, mem_queues, event))
@@ -22,35 +23,50 @@ def start_decode_kv_move_manager_process(args, info_queue: mp.Queue, mem_queues:
2223
logger.info("decode kv move manager process started")
2324
return
2425

26+
2527
def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event):
2628
import lightllm.utils.rpyc_fix_utils as _
2729

2830
# 注册graceful 退出的处理
2931
graceful_registry(inspect.currentframe().f_code.co_name)
30-
32+
3133
from .up_status import start_up_kv_status_process
3234

3335
up_status_in_queue = mp.SimpleQueue()
3436
start_up_kv_status_process(args, up_status_in_queue)
3537

3638
from .decode_trans_process import start_decode_trans_process
37-
manager = DecodeKVMoveManager(args=args,
38-
info_queue=info_queue,
39-
mem_queues=mem_queues,
40-
start_trans_process_func=start_decode_trans_process,
41-
up_status_in_queue=up_status_in_queue)
39+
40+
manager = DecodeKVMoveManager(
41+
args=args,
42+
info_queue=info_queue,
43+
mem_queues=mem_queues,
44+
start_trans_process_func=start_decode_trans_process,
45+
up_status_in_queue=up_status_in_queue,
46+
)
47+
assert manager is not None
4248
event.set()
43-
while True: time.sleep(100)
49+
while True:
50+
time.sleep(100)
4451
return
4552

4653

4754
class DecodeKVMoveManager(BaseKVMoveManager):
48-
def __init__(self, args: StartArgs, info_queue: mp.Queue, mem_queues: List[mp.Queue], start_trans_process_func: Callable, up_status_in_queue: mp.SimpleQueue):
49-
super().__init__(args=args,
50-
info_queue=info_queue,
51-
mem_queues=mem_queues,
52-
start_trans_process_func=start_trans_process_func,
53-
up_status_in_queue=up_status_in_queue)
55+
def __init__(
56+
self,
57+
args: StartArgs,
58+
info_queue: mp.Queue,
59+
mem_queues: List[mp.Queue],
60+
start_trans_process_func: Callable,
61+
up_status_in_queue: mp.SimpleQueue,
62+
):
63+
super().__init__(
64+
args=args,
65+
info_queue=info_queue,
66+
mem_queues=mem_queues,
67+
start_trans_process_func=start_trans_process_func,
68+
up_status_in_queue=up_status_in_queue,
69+
)
5470
return
5571

5672
# ==================================================================================
@@ -60,12 +76,14 @@ def __init__(self, args: StartArgs, info_queue: mp.Queue, mem_queues: List[mp.Qu
6076
def task_dispatcher_loop(self):
6177
# 获取任务,并分发给相关卡的处理队列
6278
while True:
63-
task_group:NIXLChunckedTransTaskGroup = self.info_queue.get()
79+
task_group: NIXLChunckedTransTaskGroup = self.info_queue.get()
6480
device_id = task_group.task_list[0].dst_device_id
6581
try:
6682
trans_process: KVTransProcess = self.kv_trans_processes[device_id]
6783
trans_process.task_in_queue.put(task_group)
68-
logger.info(f"kv move manager dispatch task group {task_group.task_list[0].to_str()} to device {device_id}")
84+
logger.info(
85+
f"kv move manager dispatch task group {task_group.task_list[0].to_str()} to device {device_id}"
86+
)
6987

7088
except BaseException as e:
71-
logger.exception(str(e))
89+
logger.exception(str(e))

lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py

Lines changed: 62 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
from typing import List, Dict, Union, Deque, Optional
1010
from lightllm.utils.log_utils import init_logger
1111
from lightllm.common.mem_manager import MemoryManager
12-
from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskGroup, NIXLChunckedTransTaskRet, NixlUpKVStatus
12+
from lightllm.server.pd_io_struct import (
13+
NIXLChunckedTransTask,
14+
NIXLChunckedTransTaskGroup,
15+
NIXLChunckedTransTaskRet,
16+
NixlUpKVStatus,
17+
)
1318
from lightllm.server.pd_io_struct import NIXLDecodeNodeInfo
1419
from lightllm.utils.device_utils import kv_trans_use_p2p
1520
from lightllm.utils.graceful_utils import graceful_registry
@@ -28,7 +33,9 @@ def start_decode_trans_process(
2833
mem_queues: List[mp.Queue],
2934
up_status_in_queue: Optional[mp.SimpleQueue],
3035
):
31-
proc = mp.Process(target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues, up_status_in_queue))
36+
proc = mp.Process(
37+
target=_init_env, args=(args, device_id, task_in_queue, task_out_queue, mem_queues, up_status_in_queue)
38+
)
3239
proc.start()
3340
assert proc.is_alive()
3441
logger.info(f"prefill trans kv process for device: {device_id} started!")
@@ -53,13 +60,18 @@ def _init_env(
5360
mem_managers: List[MemoryManager] = [mem_queue.get(timeout=60) for mem_queue in mem_queues]
5461
task_out_queue.put("get_mem_managers_ok")
5562

56-
manager = _DecodeTransModule(args=args,
57-
device_id=device_id,
58-
task_in_queue=task_in_queue,
59-
task_out_queue=task_out_queue,
60-
mem_managers=mem_managers,
61-
up_status_in_queue=up_status_in_queue)
62-
while True: time.sleep(100)
63+
manager = _DecodeTransModule(
64+
args=args,
65+
device_id=device_id,
66+
task_in_queue=task_in_queue,
67+
task_out_queue=task_out_queue,
68+
mem_managers=mem_managers,
69+
up_status_in_queue=up_status_in_queue,
70+
)
71+
assert manager is not None
72+
73+
while True:
74+
time.sleep(100)
6375

6476
except Exception as e:
6577
logger.exception(str(e))
@@ -75,7 +87,8 @@ def __init__(
7587
task_in_queue: mp.Queue,
7688
task_out_queue: mp.Queue,
7789
mem_managers: List[MemoryManager],
78-
up_status_in_queue: Optional[mp.SimpleQueue]):
90+
up_status_in_queue: Optional[mp.SimpleQueue],
91+
):
7992
self.args = args
8093
self.dp_world_size = self.args.tp // self.args.dp
8194
self.device_id = device_id
@@ -84,12 +97,13 @@ def __init__(
8497
self.mem_managers = mem_managers
8598
self.up_status_in_queue = up_status_in_queue
8699
cur_mem_manager: MemoryManager = self.mem_managers[device_id]
87-
kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer(page_num=self.args.nixl_pd_kv_page_num,
88-
page_size=self.args.nixl_pd_kv_page_size)
100+
kv_move_buffer = cur_mem_manager.alloc_paged_kv_move_buffer(
101+
page_num=self.args.nixl_pd_kv_page_num, page_size=self.args.nixl_pd_kv_page_size
102+
)
89103
self.copy_cuda_stream = torch.cuda.Stream()
90-
self.transporter = NixlKVTransporter(node_id=self.args.pd_node_id,
91-
tp_idx=device_id,
92-
kv_move_buffer=kv_move_buffer)
104+
self.transporter = NixlKVTransporter(
105+
node_id=self.args.pd_node_id, tp_idx=device_id, kv_move_buffer=kv_move_buffer
106+
)
93107
self.waiting_dict_lock = threading.Lock()
94108
self.waiting_dict: Dict[str, NIXLChunckedTransTask] = {}
95109
self.read_peer_kv_queue = queue.Queue()
@@ -102,11 +116,19 @@ def __init__(
102116
self.page_index_queue = queue.Queue()
103117
for page_index in range(self.args.nixl_pd_kv_page_num):
104118
self.page_index_queue.put(page_index)
105-
106-
for func in [self.recv_task_loop, self.accept_peer_task_loop, self.read_peer_kv_loop, self.update_task_status_loop, self.read_page_to_mems_loop, self.success_loop, self.fail_loop]:
119+
120+
for func in [
121+
self.recv_task_loop,
122+
self.accept_peer_task_loop,
123+
self.read_peer_kv_loop,
124+
self.update_task_status_loop,
125+
self.read_page_to_mems_loop,
126+
self.success_loop,
127+
self.fail_loop,
128+
]:
107129
threading.Thread(target=func, daemon=True).start()
108130
return
109-
131+
110132
@log_exception
111133
def recv_task_loop(self):
112134
while True:
@@ -119,7 +141,7 @@ def recv_task_loop(self):
119141
else:
120142
task.start_trans_time = time.time()
121143
self.success_queue.put((None, task))
122-
144+
123145
# up status
124146
task = trans_task_group.task_list[0]
125147

@@ -137,7 +159,7 @@ def recv_task_loop(self):
137159
up_status = NixlUpKVStatus(
138160
group_request_id=task.request_id,
139161
pd_master_node_id=task.pd_master_node_id,
140-
nixl_params=pickle.dumps(decode_node_info)
162+
nixl_params=pickle.dumps(decode_node_info),
141163
)
142164

143165
self.up_status_in_queue.put(up_status)
@@ -151,7 +173,7 @@ def accept_peer_task_loop(
151173
if len(self.waiting_dict) == 0:
152174
time.sleep(0.003)
153175
continue
154-
176+
155177
# notify update
156178
try:
157179
notifies_dict = self.transporter.get_new_notifs()
@@ -167,17 +189,20 @@ def accept_peer_task_loop(
167189
notify_obj = pickle.loads(notify)
168190
except:
169191
notify_obj = None
170-
192+
171193
if isinstance(notify_obj, NIXLChunckedTransTask):
172194
remote_trans_task = notify_obj
173195
key = remote_trans_task.get_key()
174196
logger.info(f"recv peer trans task {remote_trans_task.to_str()}")
175197
with self.waiting_dict_lock:
176-
local_trans_task : NIXLChunckedTransTask = self.waiting_dict.pop(key, None)
177-
198+
local_trans_task: NIXLChunckedTransTask = self.waiting_dict.pop(key, None)
199+
178200
if local_trans_task is None:
179201
remote_trans_task.error_info = "peer not find"
180-
self.transporter.send_notify_to_prefill_node(prefill_agent_name=remote_agent_name, notify=pickle.dumps(remote_trans_task.createRetObj()))
202+
self.transporter.send_notify_to_prefill_node(
203+
prefill_agent_name=remote_agent_name,
204+
notify=pickle.dumps(remote_trans_task.createRetObj()),
205+
)
181206
else:
182207
local_trans_task.nixl_src_page_index = remote_trans_task.nixl_src_page_index
183208

@@ -189,17 +214,16 @@ def accept_peer_task_loop(
189214
self.read_peer_kv_queue.put(local_trans_task)
190215

191216
self._check_tasks_time_out()
192-
193217

194218
def _check_tasks_time_out(self):
195219
# check time_out update
196220
with self.waiting_dict_lock:
197221
keys = list(self.waiting_dict.keys())
198-
222+
199223
for key in keys:
200224
with self.waiting_dict_lock:
201225
trans_task = self.waiting_dict.pop(key, None)
202-
226+
203227
if trans_task is not None and trans_task.time_out():
204228
trans_task.error_info = "time out in accept_peer_task_loop"
205229
self.failed_queue.put(trans_task)
@@ -209,7 +233,6 @@ def _check_tasks_time_out(self):
209233
with self.waiting_dict_lock:
210234
self.waiting_dict[trans_task.get_key()] = trans_task
211235
return
212-
213236

214237
@log_exception
215238
def read_peer_kv_loop(self):
@@ -224,7 +247,7 @@ def read_peer_kv_loop(self):
224247
local_trans_task.error_info = "time out in read_peer_kv_loop"
225248
self.failed_queue.put(local_trans_task)
226249
continue
227-
250+
228251
try:
229252
xfer_handle = self.transporter.read_blocks_paged(trans_task=local_trans_task)
230253
local_trans_task.xfer_handle = xfer_handle
@@ -239,7 +262,6 @@ def read_peer_kv_loop(self):
239262
self.failed_queue.put(local_trans_task)
240263
continue
241264

242-
243265
@log_exception
244266
def update_task_status_loop(
245267
self,
@@ -253,7 +275,7 @@ def update_task_status_loop(
253275
with self.update_status_task_list_lock:
254276
trans_taskes = self.update_status_task_list.copy()
255277
self.update_status_task_list.clear()
256-
278+
257279
for trans_task in trans_taskes:
258280
ret = self.transporter.check_task_status(trans_task=trans_task)
259281
if ret == "DONE":
@@ -263,7 +285,7 @@ def update_task_status_loop(
263285
trans_task.error_info = "xfer error"
264286
self.failed_queue.put(trans_task)
265287
continue
266-
288+
267289
if trans_task.time_out():
268290
trans_task.error_info = "time out"
269291
self.failed_queue.put(trans_task)
@@ -272,7 +294,6 @@ def update_task_status_loop(
272294
with self.update_status_task_list_lock:
273295
self.update_status_task_list.append(trans_task)
274296

275-
276297
@log_exception
277298
def read_page_to_mems_loop(self):
278299
torch.cuda.set_device(self.device_id)
@@ -286,14 +307,14 @@ def read_page_to_mems_loop(self):
286307
page_index=trans_task.nixl_dst_page_index,
287308
dp_index=trans_task.decode_dp_index,
288309
mem_managers=self.mem_managers,
289-
dp_world_size=self.dp_world_size
310+
dp_world_size=self.dp_world_size,
290311
)
291312
sync_event = torch.cuda.Event()
292313
sync_event.record()
293314

294315
self.success_queue.put((sync_event, trans_task))
295316
return
296-
317+
297318
@log_exception
298319
def success_loop(self):
299320
torch.cuda.set_device(self.device_id)
@@ -304,17 +325,17 @@ def success_loop(self):
304325
# 兼容传输kv 数量为0的时候, sync_event 为 None的情况。
305326
if sync_event is not None:
306327
sync_event.synchronize()
307-
328+
308329
if trans_task.nixl_dst_page_index is not None:
309330
self.page_index_queue.put(trans_task.nixl_dst_page_index)
310-
331+
311332
if trans_task.xfer_handle is not None:
312333
self.transporter.release_xfer_handle(trans_task.xfer_handle)
313-
334+
314335
ret = trans_task.createRetObj()
315336
self.task_out_queue.put(ret)
316337
logger.info(f"trans task ret success:{ret} cost time: {trans_task.transfer_time()} s")
317-
338+
318339
@log_exception
319340
def fail_loop(self):
320341
torch.cuda.set_device(self.device_id)
@@ -328,4 +349,4 @@ def fail_loop(self):
328349
self.transporter.release_xfer_handle(trans_task.xfer_handle)
329350
ret = trans_task.createRetObj()
330351
self.task_out_queue.put(ret)
331-
logger.info(f"trans task ret fail:{ret}")
352+
logger.info(f"trans task ret fail:{ret}")

0 commit comments

Comments
 (0)