2929from lightllm .utils .envs_utils import get_unique_server_name
3030
3131KV_MOVE_MAX_NUM = 16
32- KV_MOVE_MAX_RESTART_CNT = 3
32+ KV_MOVE_MAX_START_CNT = 3
3333
3434logger = init_logger (__name__ )
3535
@@ -348,20 +348,13 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
348348
349349 from .prefill_trans_process import start_prefill_trans_process
350350
351- self .kv_trans_ports = []
352- self .kv_trans_processes = []
353- self .kv_trans_task_in_queues = []
354- self .kv_trans_task_out_queues = []
355- self .kv_trans_process_restart_cnt = []
351+ self .kv_trans_ports = [None ] * self . node_world_size
352+ self .kv_trans_processes = [None ] * self . node_world_size
353+ self .kv_trans_task_in_queues = [None ] * self . node_world_size
354+ self .kv_trans_task_out_queues = [None ] * self . node_world_size
355+ self .kv_trans_process_start_cnt = [0 ] * self . node_world_size
356356
357357 for device_id in range (self .node_world_size ):
358- self .kv_trans_task_in_queues .append (mp .Queue ())
359- self .kv_trans_task_out_queues .append (mp .Queue ())
360- self .kv_trans_ports .append (
361- find_available_port (self .args .pd_p_allowed_port_min , self .args .pd_p_allowed_port_max )
362- )
363- self .kv_trans_process_restart_cnt .append (0 )
364- self .kv_trans_processes .append (None )
365358 assert self .start_trans_process (device_id )
366359
367360 return
@@ -385,10 +378,10 @@ def handle_release_task_loop(self):
385378 return
386379
387380 def start_trans_process (self , device_id : int ):
388- task_in_queue = self . kv_trans_task_in_queues [ device_id ]
389- task_out_queue = self . kv_trans_task_out_queues [ device_id ]
390- kv_trans_port = self .kv_trans_ports [ device_id ]
391- self .kv_trans_process_restart_cnt [device_id ] += 1
381+ task_in_queue = mp . Queue ()
382+ task_out_queue = mp . Queue ()
383+ kv_trans_port = find_available_port ( self .args . pd_p_allowed_port_min , self . args . pd_p_allowed_port_max )
384+ self .kv_trans_process_start_cnt [device_id ] += 1
392385
393386 if self .kv_trans_processes [device_id ]:
394387 # force kill
@@ -417,6 +410,9 @@ def start_trans_process(self, device_id: int):
417410 assert task_out_queue .get (timeout = 60 ) == "get_mem_managers_ok"
418411
419412 self .kv_trans_processes [device_id ] = kv_trans_process
413+ self .kv_trans_task_in_queues [device_id ] = task_in_queue
414+ self .kv_trans_task_out_queues [device_id ] = task_out_queue
415+ self .kv_trans_ports [device_id ] = kv_trans_port
420416
421417 return True
422418 except Exception as e :
@@ -454,7 +450,7 @@ def check_trans_process_loop(self):
454450 raise e
455451
456452 def is_kv_trans_process_alive (self , device_id ):
457- return self .kv_trans_process_restart_cnt [device_id ] <= KV_MOVE_MAX_RESTART_CNT
453+ return self .kv_trans_process_start_cnt [device_id ] <= KV_MOVE_MAX_START_CNT
458454
459455 def get_next_device_index (self ):
460456
0 commit comments