2929thread_local_data = threading .local ()
3030
3131KV_MOVE_MAX_NUM = 16
32+ KV_MOVE_MAX_RESTART_CNT = 3
3233
3334
3435@dataclass
3536class TransProcessObj :
3637 prefill_node_id : int = None
38+ process : mp .Process = None
3739 task_in_queue : mp .Queue = None
3840 task_out_queue : mp .Queue = None
39- prefill_ip : str = None
40- prefill_port : int = None
41+ pd_prefill_nccl_ip : str = None
42+ pd_prefill_nccl_port : int = None
4143 device_index : int = None
4244 manager : "DecodeKVMoveManager" = None
4345 has_error : bool = False
@@ -47,32 +49,36 @@ class TransProcessObj:
4749 put_to_radix_thread : threading .Thread = None
4850 latest_check_time : float = None
4951
50- def create (self , prefill_node_id : str , prefill_ip : str , prefill_port : int , manager : "DecodeKVMoveManager" ):
52+ def create (
53+ self , prefill_node_id : str , pd_prefill_nccl_ip : str , pd_prefill_nccl_port : int , manager : "DecodeKVMoveManager"
54+ ):
5155
5256 device_index = manager .get_next_device_index ()
5357 decode_node_id = manager .args .pd_node_id
5458 task_in_queue = manager .kv_trans_task_in_queues [device_index ]
5559 task_out_queue = manager .kv_trans_task_out_queues [device_index ]
5660
57- task_in_queue .put (
58- PDTransJoinInfo (
59- prefill_id = prefill_node_id ,
60- prefill_device_id = - 1 ,
61- prefill_ip = prefill_ip ,
62- prefill_port = prefill_port ,
63- decode_id = decode_node_id ,
64- decode_device_id = device_index ,
61+ with manager .device_locks [device_index ]:
62+ task_in_queue .put (
63+ PDTransJoinInfo (
64+ prefill_id = prefill_node_id ,
65+ prefill_device_id = - 1 ,
66+ pd_prefill_nccl_ip = pd_prefill_nccl_ip ,
67+ pd_prefill_nccl_port = pd_prefill_nccl_port ,
68+ decode_id = decode_node_id ,
69+ decode_device_id = device_index ,
70+ )
6571 )
66- )
67- assert task_out_queue .get (timeout = 60 ) == "nccl_ok"
72+ assert task_out_queue .get (timeout = 60 ) == "nccl_ok"
6873
6974 self .prefill_node_id = prefill_node_id
7075 self .decode_node_id = decode_node_id
7176 self .task_in_queue = task_in_queue
7277 self .task_out_queue = task_out_queue
73- self .prefill_ip = prefill_ip
74- self .prefill_port = prefill_port
78+ self .pd_prefill_nccl_ip = pd_prefill_nccl_ip
79+ self .pd_prefill_nccl_port = pd_prefill_nccl_port
7580 self .device_index = device_index
81+ self .process = manager .kv_trans_processes [device_index ]
7682
7783 self .manager = manager
7884 self .latest_check_time = time .time ()
@@ -90,6 +96,20 @@ def create(self, prefill_node_id: str, prefill_ip: str, prefill_port: int, manag
9096 self .put_to_radix_thread .start ()
9197 return
9298
def check_trans_process(self, raise_exception=True):
    """Check that the kv trans process used by this trans obj is still alive.

    The process is treated as dead when its pid no longer exists, when it is
    not running, or when it is a zombie. A dead process marks this object via
    ``set_has_error()``.

    Args:
        raise_exception: when true, raise after marking the error state.

    Raises:
        Exception: if the process is dead and ``raise_exception`` is true.
    """
    pid = self.process.pid
    try:
        process = psutil.Process(pid)
        is_alive = process.is_running() and process.status() != psutil.STATUS_ZOMBIE
    except psutil.NoSuchProcess:
        # The pid has already disappeared (the process exited and was reaped).
        # That is a dead process, not an unexpected failure, so it must follow
        # the same error path and honour ``raise_exception``.
        is_alive = False

    if not is_alive:
        self.set_has_error()
        if raise_exception:
            raise Exception(f"trans process: {pid} is dead")
    return
106+
def timer_to_check_status(self, raise_exception=True):
    """Rate-limited wrapper around ``check_trans_process``.

    The check only runs when at least 2 seconds have passed since
    ``latest_check_time``; the timestamp is refreshed right before the check.

    Args:
        raise_exception: forwarded unchanged to ``check_trans_process``.
    """
    elapsed = time.time() - self.latest_check_time
    if elapsed < 2.0:
        # Checked recently enough, nothing to do.
        return
    self.latest_check_time = time.time()
    self.check_trans_process(raise_exception=raise_exception)
    return
112+
93113 def _transfer_kv (self , move_tasks : List [KVMoveTask ]):
94114 with self .manager .device_locks [self .device_index ]:
95115 self .task_in_queue .put (move_tasks .copy (), timeout = 10 )
@@ -120,6 +140,7 @@ def kv_move_loop(self):
120140 logger .info (f"{ func_name } get task { task .to_decode_log_info ()} " )
121141
122142 try :
143+ self .timer_to_check_status (raise_exception = True )
123144 if not kv_trans_use_p2p ():
124145 with self .manager .kv_trans_lock :
125146 self ._transfer_kv (move_tasks )
@@ -150,6 +171,7 @@ def put_to_radix_loop(self):
150171 logger .info (f"{ func_name } get put radix task { task .to_decode_log_info ()} " )
151172
152173 try :
174+ self .timer_to_check_status (raise_exception = True )
153175 # random to check stats
154176 self .manager ._put_kv_received_to_radix_cache (move_tasks .copy ())
155177 for task in move_tasks .copy ():
@@ -266,31 +288,17 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
266288 # 需要每个卡有一个锁来规划每次只能有一个tran obj 操作对应显卡上的传输任务。
267289 self .device_locks = [threading .Lock () for _ in range (self .node_world_size )]
268290
269- from .decode_trans_process import start_decode_trans_process
270-
271291 self .kv_trans_processes = []
272292 self .kv_trans_task_in_queues = []
273293 self .kv_trans_task_out_queues = []
274- self .kv_trans_process_alive = []
275-
276- for device_index in range (self .node_world_size ):
277- kv_trans_task_in_queue = mp .Queue ()
278- kv_trans_task_out_queue = mp .Queue ()
279- kv_trans_process = start_decode_trans_process (
280- self .args ,
281- device_index ,
282- kv_trans_task_in_queue ,
283- kv_trans_task_out_queue ,
284- self .mem_queues ,
285- )
286- assert kv_trans_task_out_queue .get (timeout = 30 ) == "proc_start"
287- self ._put_mem_manager_to_mem_queue ()
288- assert kv_trans_task_out_queue .get (timeout = 60 ) == "get_mem_managers_ok"
294+ self .kv_trans_process_restart_cnt = []
289295
290- self .kv_trans_processes .append (kv_trans_process )
291- self .kv_trans_task_in_queues .append (kv_trans_task_in_queue )
292- self .kv_trans_task_out_queues .append (kv_trans_task_out_queue )
293- self .kv_trans_process_alive .append (True )
296+ for device_id in range (self .node_world_size ):
297+ self .kv_trans_task_in_queues .append (mp .Queue ())
298+ self .kv_trans_task_out_queues .append (mp .Queue ())
299+ self .kv_trans_process_restart_cnt .append (0 )
300+ self .kv_trans_processes .append (None )
301+ assert self .start_trans_process (device_id )
294302
295303 return
296304
@@ -400,17 +408,19 @@ def exposed_check_alive(self):
400408 # 用于 prefill node check 通信连接的状态。
401409 return
402410
def exposed_build_trans_process(
    self, prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, prefill_node_max_kv_trans_num
):
    """Build a fresh trans obj for a prefill node and register it.

    Any trans obj already registered under ``prefill_node_id`` is removed
    first. Returns the smaller of ``prefill_node_max_kv_trans_num`` and
    ``self.args.max_total_token_num``.
    """
    # Each argument goes through `obtain` before use.
    # NOTE(review): presumably this turns rpyc proxies into local values; confirm.
    prefill_node_id = obtain(prefill_node_id)
    pd_prefill_nccl_ip = obtain(pd_prefill_nccl_ip)
    pd_prefill_nccl_port = obtain(pd_prefill_nccl_port)
    prefill_node_max_kv_trans_num = obtain(prefill_node_max_kv_trans_num)
    thread_local_data.prefill_node_id = prefill_node_id

    logger.info(f"build trans infos {prefill_node_id} {pd_prefill_nccl_ip} {pd_prefill_nccl_port}")
    # If a leftover trans obj from an earlier connection exists, remove it too.
    self.remove_trans_obj(prefill_node_id)

    new_trans_obj = TransProcessObj()
    new_trans_obj.create(prefill_node_id, pd_prefill_nccl_ip, pd_prefill_nccl_port, self)
    self.node_id_to_trans_obj[prefill_node_id] = new_trans_obj

    token_limit = self.args.max_total_token_num
    return min(prefill_node_max_kv_trans_num, token_limit)
416426
@@ -476,7 +486,7 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona
476486
477487 def get_next_device_index (self ):
478488 counts = [
479- 0 if self .kv_trans_process_alive [ device_id ] else (1 << 20 ) for device_id in range (self .node_world_size )
489+ 0 if self .is_kv_trans_process_alive ( device_id ) else (1 << 20 ) for device_id in range (self .node_world_size )
480490 ]
481491 for obj in self .node_id_to_trans_obj .values ():
482492 counts [obj .device_index ] += 1
@@ -509,16 +519,60 @@ def remove_trans_obj(self, prefill_node_id):
509519 trans_obj .set_has_error ()
510520 return
511521
def remove_trans_obj_by_deviceid(self, device_id):
    """Remove every registered trans obj whose ``device_index`` is ``device_id``.

    Each matching node id is passed to ``remove_dead_trans_obj``.
    """
    # Collect the matching node ids from a snapshot before removing anything.
    # Removing entries from ``node_id_to_trans_obj`` while iterating its live
    # view raises "dictionary changed size during iteration" and would leave
    # the remaining trans objs of this device in place.
    node_ids = [
        node_id
        for node_id, t_obj in list(self.node_id_to_trans_obj.items())
        if t_obj.device_index == device_id
    ]
    for node_id in node_ids:
        self.remove_dead_trans_obj(node_id)
526+
def start_trans_process(self, device_id: int):
    """Start (or restart) the kv trans process for ``device_id``.

    Every call increments ``kv_trans_process_restart_cnt[device_id]``,
    including the very first start. If a process is already registered for
    the device, its trans objs are removed and it is force-killed before the
    replacement is launched.

    Args:
        device_id: index into the per-device queue / process lists.

    Returns:
        True when the new process was started and answered both handshake
        messages ("proc_start", then "get_mem_managers_ok"); False otherwise.
        After a failed restart ``kv_trans_processes[device_id]`` can be None
        (the old process was killed, no new one registered), so callers must
        not assume a process object is present.
    """
    task_in_queue = self.kv_trans_task_in_queues[device_id]
    task_out_queue = self.kv_trans_task_out_queues[device_id]
    self.kv_trans_process_restart_cnt[device_id] += 1

    old_process = self.kv_trans_processes[device_id]
    if old_process:
        # force kill
        # Cleanup and kill are guarded separately so that a failure while
        # removing trans objs can no longer silently skip the kill.
        try:
            self.remove_trans_obj_by_deviceid(device_id)
        except Exception as e:
            logger.warning(f"Failed to remove trans objs for device {device_id}: {e!r}")
        try:
            psutil.Process(old_process.pid).kill()
            self.kv_trans_processes[device_id] = None
        except Exception as e:
            # Best effort: the old process may already be gone.
            logger.warning(f"Failed to kill old kv trans process for device {device_id}: {e!r}")

    try:
        from .decode_trans_process import start_decode_trans_process

        kv_trans_process = start_decode_trans_process(
            self.args,
            device_id,
            task_in_queue,
            task_out_queue,
            self.mem_queues,
        )
        # Explicit checks instead of `assert`: asserts vanish under `python -O`
        # and carry no message when they fail.
        start_reply = task_out_queue.get(timeout=30)
        if start_reply != "proc_start":
            raise RuntimeError(f"expected 'proc_start' but got {start_reply!r}")
        self._put_mem_manager_to_mem_queue()
        mem_reply = task_out_queue.get(timeout=60)
        if mem_reply != "get_mem_managers_ok":
            raise RuntimeError(f"expected 'get_mem_managers_ok' but got {mem_reply!r}")

        self.kv_trans_processes[device_id] = kv_trans_process

        return True
    except Exception as e:
        # repr() keeps the exception type visible; a queue timeout has an empty str().
        logger.warning(f"Failed start kv trans process for device {device_id}: {e!r}")
        return False
562+
def is_kv_trans_process_alive(self, device_id):
    """Tell whether the kv trans process slot of ``device_id`` is still usable.

    The answer is derived only from the start counter: the slot counts as
    alive while ``kv_trans_process_restart_cnt[device_id]`` has not exceeded
    ``KV_MOVE_MAX_RESTART_CNT``. The process itself is not probed here.
    """
    start_attempts = self.kv_trans_process_restart_cnt[device_id]
    return KV_MOVE_MAX_RESTART_CNT >= start_attempts
565+
512566 def check_trans_process (self , raise_exception = True ):
513567 at_least_one_alive = False
514568 for device_id in range (self .node_world_size ):
515- if not self .kv_trans_process_alive [ device_id ] :
569+ if not self .is_kv_trans_process_alive ( device_id ) :
516570 continue
517571
518572 process = psutil .Process (self .kv_trans_processes [device_id ].pid )
519573 if not (process .is_running () and process .status () != psutil .STATUS_ZOMBIE ):
520- self . kv_trans_process_alive [ device_id ] = False
521- logger . error ( f"kv trans process for device: { device_id } dead!!!" )
574+ logger . error ( f"kv trans process for device: { device_id } dead!!!, try start again..." )
575+ self . start_trans_process ( device_id )
522576 else :
523577 at_least_one_alive = True
524578
@@ -530,17 +584,24 @@ def check_trans_process(self, raise_exception=True):
530584
def timer_loop(self):
    """Run forever, calling ``_unfrozen_time_out_reqs_tokens`` then sleeping 3.5s.

    Any exception that escapes the loop is logged with its traceback and
    re-raised to the caller.
    """
    try:
        while True:
            self._unfrozen_time_out_reqs_tokens()
            time.sleep(3.5)
    except BaseException as err:
        # BaseException already covers RuntimeError, so one clause suffices.
        logger.exception(str(err))
        raise err
543593
def check_trans_process_loop(self):
    """Run forever, calling ``check_trans_process`` then sleeping 10 seconds.

    If anything is raised out of the loop, it is logged, SIGTERM is sent to
    the parent process, and the exception is re-raised.
    """
    try:
        while True:
            self.check_trans_process()
            time.sleep(10.0)
    except BaseException as err:
        # BaseException already covers RuntimeError, so one clause suffices.
        logger.exception(str(err))
        # kill parent process if any exception occurred
        parent_pid = os.getppid()
        os.kill(parent_pid, signal.SIGTERM)
        raise err
604+
544605
545606def _init_env (args , info_queue : mp .Queue , mem_queues : List [mp .Queue ], event : mp .Event ):
546607 import lightllm .utils .rpyc_fix_utils as _
@@ -552,6 +613,9 @@ def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.
552613 t = ThreadedServer (manager , port = args .pd_decode_rpyc_port , protocol_config = {"allow_pickle" : True })
553614 threading .Thread (target = lambda : t .start (), daemon = True ).start ()
554615
616+ kv_trans_process_check = threading .Thread (target = manager .check_trans_process_loop , daemon = True )
617+ kv_trans_process_check .start ()
618+
555619 event .set ()
556620 manager .timer_loop ()
557621 return
0 commit comments