99from typing import List , Dict , Union , Deque , Optional
1010from lightllm .utils .log_utils import init_logger
1111from lightllm .common .mem_manager import MemoryManager
12- from lightllm .server .pd_io_struct import NIXLChunckedTransTask , NIXLChunckedTransTaskGroup , NIXLChunckedTransTaskRet , NixlUpKVStatus
12+ from lightllm .server .pd_io_struct import (
13+ NIXLChunckedTransTask ,
14+ NIXLChunckedTransTaskGroup ,
15+ NIXLChunckedTransTaskRet ,
16+ NixlUpKVStatus ,
17+ )
1318from lightllm .server .pd_io_struct import NIXLDecodeNodeInfo
1419from lightllm .utils .device_utils import kv_trans_use_p2p
1520from lightllm .utils .graceful_utils import graceful_registry
@@ -28,7 +33,9 @@ def start_decode_trans_process(
2833 mem_queues : List [mp .Queue ],
2934 up_status_in_queue : Optional [mp .SimpleQueue ],
3035):
31- proc = mp .Process (target = _init_env , args = (args , device_id , task_in_queue , task_out_queue , mem_queues , up_status_in_queue ))
36+ proc = mp .Process (
37+ target = _init_env , args = (args , device_id , task_in_queue , task_out_queue , mem_queues , up_status_in_queue )
38+ )
3239 proc .start ()
3340 assert proc .is_alive ()
3441 logger .info (f"prefill trans kv process for device: { device_id } started!" )
@@ -53,13 +60,18 @@ def _init_env(
5360 mem_managers : List [MemoryManager ] = [mem_queue .get (timeout = 60 ) for mem_queue in mem_queues ]
5461 task_out_queue .put ("get_mem_managers_ok" )
5562
56- manager = _DecodeTransModule (args = args ,
57- device_id = device_id ,
58- task_in_queue = task_in_queue ,
59- task_out_queue = task_out_queue ,
60- mem_managers = mem_managers ,
61- up_status_in_queue = up_status_in_queue )
62- while True : time .sleep (100 )
63+ manager = _DecodeTransModule (
64+ args = args ,
65+ device_id = device_id ,
66+ task_in_queue = task_in_queue ,
67+ task_out_queue = task_out_queue ,
68+ mem_managers = mem_managers ,
69+ up_status_in_queue = up_status_in_queue ,
70+ )
71+ assert manager is not None
72+
73+ while True :
74+ time .sleep (100 )
6375
6476 except Exception as e :
6577 logger .exception (str (e ))
@@ -75,7 +87,8 @@ def __init__(
7587 task_in_queue : mp .Queue ,
7688 task_out_queue : mp .Queue ,
7789 mem_managers : List [MemoryManager ],
78- up_status_in_queue : Optional [mp .SimpleQueue ]):
90+ up_status_in_queue : Optional [mp .SimpleQueue ],
91+ ):
7992 self .args = args
8093 self .dp_world_size = self .args .tp // self .args .dp
8194 self .device_id = device_id
@@ -84,12 +97,13 @@ def __init__(
8497 self .mem_managers = mem_managers
8598 self .up_status_in_queue = up_status_in_queue
8699 cur_mem_manager : MemoryManager = self .mem_managers [device_id ]
87- kv_move_buffer = cur_mem_manager .alloc_paged_kv_move_buffer (page_num = self .args .nixl_pd_kv_page_num ,
88- page_size = self .args .nixl_pd_kv_page_size )
100+ kv_move_buffer = cur_mem_manager .alloc_paged_kv_move_buffer (
101+ page_num = self .args .nixl_pd_kv_page_num , page_size = self .args .nixl_pd_kv_page_size
102+ )
89103 self .copy_cuda_stream = torch .cuda .Stream ()
90- self .transporter = NixlKVTransporter (node_id = self . args . pd_node_id ,
91- tp_idx = device_id ,
92- kv_move_buffer = kv_move_buffer )
104+ self .transporter = NixlKVTransporter (
105+ node_id = self . args . pd_node_id , tp_idx = device_id , kv_move_buffer = kv_move_buffer
106+ )
93107 self .waiting_dict_lock = threading .Lock ()
94108 self .waiting_dict : Dict [str , NIXLChunckedTransTask ] = {}
95109 self .read_peer_kv_queue = queue .Queue ()
@@ -102,11 +116,19 @@ def __init__(
102116 self .page_index_queue = queue .Queue ()
103117 for page_index in range (self .args .nixl_pd_kv_page_num ):
104118 self .page_index_queue .put (page_index )
105-
106- for func in [self .recv_task_loop , self .accept_peer_task_loop , self .read_peer_kv_loop , self .update_task_status_loop , self .read_page_to_mems_loop , self .success_loop , self .fail_loop ]:
119+
120+ for func in [
121+ self .recv_task_loop ,
122+ self .accept_peer_task_loop ,
123+ self .read_peer_kv_loop ,
124+ self .update_task_status_loop ,
125+ self .read_page_to_mems_loop ,
126+ self .success_loop ,
127+ self .fail_loop ,
128+ ]:
107129 threading .Thread (target = func , daemon = True ).start ()
108130 return
109-
131+
110132 @log_exception
111133 def recv_task_loop (self ):
112134 while True :
@@ -119,7 +141,7 @@ def recv_task_loop(self):
119141 else :
120142 task .start_trans_time = time .time ()
121143 self .success_queue .put ((None , task ))
122-
144+
123145 # up status
124146 task = trans_task_group .task_list [0 ]
125147
@@ -137,7 +159,7 @@ def recv_task_loop(self):
137159 up_status = NixlUpKVStatus (
138160 group_request_id = task .request_id ,
139161 pd_master_node_id = task .pd_master_node_id ,
140- nixl_params = pickle .dumps (decode_node_info )
162+ nixl_params = pickle .dumps (decode_node_info ),
141163 )
142164
143165 self .up_status_in_queue .put (up_status )
@@ -151,7 +173,7 @@ def accept_peer_task_loop(
151173 if len (self .waiting_dict ) == 0 :
152174 time .sleep (0.003 )
153175 continue
154-
176+
155177 # notify update
156178 try :
157179 notifies_dict = self .transporter .get_new_notifs ()
@@ -167,17 +189,20 @@ def accept_peer_task_loop(
167189 notify_obj = pickle .loads (notify )
168190 except :
169191 notify_obj = None
170-
192+
171193 if isinstance (notify_obj , NIXLChunckedTransTask ):
172194 remote_trans_task = notify_obj
173195 key = remote_trans_task .get_key ()
174196 logger .info (f"recv peer trans task { remote_trans_task .to_str ()} " )
175197 with self .waiting_dict_lock :
176- local_trans_task : NIXLChunckedTransTask = self .waiting_dict .pop (key , None )
177-
198+ local_trans_task : NIXLChunckedTransTask = self .waiting_dict .pop (key , None )
199+
178200 if local_trans_task is None :
179201 remote_trans_task .error_info = "peer not find"
180- self .transporter .send_notify_to_prefill_node (prefill_agent_name = remote_agent_name , notify = pickle .dumps (remote_trans_task .createRetObj ()))
202+ self .transporter .send_notify_to_prefill_node (
203+ prefill_agent_name = remote_agent_name ,
204+ notify = pickle .dumps (remote_trans_task .createRetObj ()),
205+ )
181206 else :
182207 local_trans_task .nixl_src_page_index = remote_trans_task .nixl_src_page_index
183208
@@ -189,17 +214,16 @@ def accept_peer_task_loop(
189214 self .read_peer_kv_queue .put (local_trans_task )
190215
191216 self ._check_tasks_time_out ()
192-
193217
194218 def _check_tasks_time_out (self ):
195219 # check time_out update
196220 with self .waiting_dict_lock :
197221 keys = list (self .waiting_dict .keys ())
198-
222+
199223 for key in keys :
200224 with self .waiting_dict_lock :
201225 trans_task = self .waiting_dict .pop (key , None )
202-
226+
203227 if trans_task is not None and trans_task .time_out ():
204228 trans_task .error_info = "time out in accept_peer_task_loop"
205229 self .failed_queue .put (trans_task )
@@ -209,7 +233,6 @@ def _check_tasks_time_out(self):
209233 with self .waiting_dict_lock :
210234 self .waiting_dict [trans_task .get_key ()] = trans_task
211235 return
212-
213236
214237 @log_exception
215238 def read_peer_kv_loop (self ):
@@ -224,7 +247,7 @@ def read_peer_kv_loop(self):
224247 local_trans_task .error_info = "time out in read_peer_kv_loop"
225248 self .failed_queue .put (local_trans_task )
226249 continue
227-
250+
228251 try :
229252 xfer_handle = self .transporter .read_blocks_paged (trans_task = local_trans_task )
230253 local_trans_task .xfer_handle = xfer_handle
@@ -239,7 +262,6 @@ def read_peer_kv_loop(self):
239262 self .failed_queue .put (local_trans_task )
240263 continue
241264
242-
243265 @log_exception
244266 def update_task_status_loop (
245267 self ,
@@ -253,7 +275,7 @@ def update_task_status_loop(
253275 with self .update_status_task_list_lock :
254276 trans_taskes = self .update_status_task_list .copy ()
255277 self .update_status_task_list .clear ()
256-
278+
257279 for trans_task in trans_taskes :
258280 ret = self .transporter .check_task_status (trans_task = trans_task )
259281 if ret == "DONE" :
@@ -263,7 +285,7 @@ def update_task_status_loop(
263285 trans_task .error_info = "xfer error"
264286 self .failed_queue .put (trans_task )
265287 continue
266-
288+
267289 if trans_task .time_out ():
268290 trans_task .error_info = "time out"
269291 self .failed_queue .put (trans_task )
@@ -272,7 +294,6 @@ def update_task_status_loop(
272294 with self .update_status_task_list_lock :
273295 self .update_status_task_list .append (trans_task )
274296
275-
276297 @log_exception
277298 def read_page_to_mems_loop (self ):
278299 torch .cuda .set_device (self .device_id )
@@ -286,14 +307,14 @@ def read_page_to_mems_loop(self):
286307 page_index = trans_task .nixl_dst_page_index ,
287308 dp_index = trans_task .decode_dp_index ,
288309 mem_managers = self .mem_managers ,
289- dp_world_size = self .dp_world_size
310+ dp_world_size = self .dp_world_size ,
290311 )
291312 sync_event = torch .cuda .Event ()
292313 sync_event .record ()
293314
294315 self .success_queue .put ((sync_event , trans_task ))
295316 return
296-
317+
297318 @log_exception
298319 def success_loop (self ):
299320 torch .cuda .set_device (self .device_id )
@@ -304,17 +325,17 @@ def success_loop(self):
304325 # 兼容传输kv 数量为0的时候, sync_event 为 None的情况。
305326 if sync_event is not None :
306327 sync_event .synchronize ()
307-
328+
308329 if trans_task .nixl_dst_page_index is not None :
309330 self .page_index_queue .put (trans_task .nixl_dst_page_index )
310-
331+
311332 if trans_task .xfer_handle is not None :
312333 self .transporter .release_xfer_handle (trans_task .xfer_handle )
313-
334+
314335 ret = trans_task .createRetObj ()
315336 self .task_out_queue .put (ret )
316337 logger .info (f"trans task ret success:{ ret } cost time: { trans_task .transfer_time ()} s" )
317-
338+
318339 @log_exception
319340 def fail_loop (self ):
320341 torch .cuda .set_device (self .device_id )
@@ -328,4 +349,4 @@ def fail_loop(self):
328349 self .transporter .release_xfer_handle (trans_task .xfer_handle )
329350 ret = trans_task .createRetObj ()
330351 self .task_out_queue .put (ret )
331- logger .info (f"trans task ret fail:{ ret } " )
352+ logger .info (f"trans task ret fail:{ ret } " )
0 commit comments