2323CHECK_STATUS_INTERVAL = 3
2424
2525ROUTER_TYPES = ["round_robin" , "load_balancing" , "kv_cache_aware" ]
26- USED_PORTS = set ()
27-
28-
29- # get_free_port doesn't guarantee that consecutive calls will return different ports
30- # if no server is bound to the port immediately after the call
31- def get_free_unused_port ():
32- global USED_PORTS
33- max_attempts = 100
34- for _ in range (max_attempts ):
35- port = get_free_port ()
36- assert port > 0 , f"get_free_port returned { port } "
37- if port not in USED_PORTS :
38- USED_PORTS .add (port )
39- return port
40- else :
41- logger .info (f"Port { port } is already used, trying another one" )
42- raise Exception (
43- f"Failed to find a free unused port after { max_attempts } attempts" )
4426
4527
4628@pytest .fixture
@@ -53,7 +35,7 @@ def model_name():
5335
5436@pytest .fixture
5537def disagg_port ():
56- return get_free_unused_port ()
38+ return get_free_port ()
5739
5840
5941@pytest .fixture
@@ -145,8 +127,6 @@ def _run_worker(model_name,
145127 work_dir ,
146128 device = - 1 ,
147129 save_log = False ):
148- if port == 0 :
149- port = get_free_unused_port ()
150130 worker_config_path = os .path .join (work_dir , f"{ role } _{ port } _config.yaml" )
151131 with open (worker_config_path , "w+" ) as f :
152132 yaml .dump (worker_config , f )
@@ -187,6 +167,7 @@ def _run_worker(model_name,
187167 port = port )
188168
189169
170+ # Use 0 as the port and provide disagg_cluster_config to let the worker choose a free port
190171def run_ctx_worker (model_name , ctx_worker_config , work_dir , port = 0 , device = 0 ):
191172 return _run_worker (model_name , ctx_worker_config , "ctx" , port , work_dir ,
192173 device )
@@ -246,19 +227,38 @@ async def wrapper(*args, **kwargs):
246227 return decorator
247228
248229
249- @periodic_check (timeout = 300 , interval = 3 )
250- async def wait_for_disagg_server_ready (port ):
230+ async def _wait_for_disagg_server_status (port ,
231+ ready = True ,
232+ min_ctx_workers = - 1 ,
233+ min_gen_workers = - 1 ):
251234 info_resp = requests .get (f"http://localhost:{ port } /cluster_info" )
252235 logger .info (
253236 f"Waiting for disagg server { port } to be ready: { info_resp .json ()} " )
254237 if info_resp .status_code == 200 :
255238 info = info_resp .json ()
256- return info ["is_ready" ]
257- else :
258- logger .info (f"Failed to get cluster info: { info_resp .status_code } " )
239+ if ready :
240+ return info ["is_ready" ]
241+ else :
242+ return len (info ["current_workers" ]
243+ ["context_servers" ]) >= min_ctx_workers and len (
244+ info ["current_workers" ]
245+ ["generation_servers" ]) >= min_gen_workers
259246 return False
260247
261248
249+ @periodic_check (timeout = 300 , interval = 3 )
250+ async def wait_for_disagg_server_ready (port ):
251+ return await _wait_for_disagg_server_status (port , True )
252+
253+
254+ @periodic_check (timeout = 300 , interval = 3 )
255+ async def wait_for_disagg_server_status (port ,
256+ min_ctx_workers = - 1 ,
257+ min_gen_workers = - 1 ):
258+ return await _wait_for_disagg_server_status (port , False , min_ctx_workers ,
259+ min_gen_workers )
260+
261+
262262@periodic_check (timeout = 300 , interval = 3 )
263263async def wait_for_worker_ready (port ):
264264 logger .info (f"Waiting for worker { port } to be ready" )
@@ -314,9 +314,6 @@ def terminate(*args, show_log_lines=30, release_port=True):
314314 if arg .log_file :
315315 arg .log_file .close ()
316316 arg .log_file = None
317- if release_port :
318- global USED_PORTS
319- USED_PORTS .discard (arg .port )
320317 except Exception :
321318 print (f"Failed to terminate process { arg .process .pid } " )
322319 else :
@@ -399,8 +396,7 @@ async def test_minimal_instances(model_name, disagg_server_config,
399396 gen_worker1 = run_gen_worker (model_name , worker_config , work_dir )
400397 disagg_server = run_disagg_server (disagg_server_config , work_dir ,
401398 disagg_port )
402- await wait_for_worker_ready (ctx_worker1 .port )
403- await wait_for_worker_ready (gen_worker1 .port )
399+ await wait_for_disagg_server_status (disagg_port , 1 , 1 )
404400 verify_cluster_info (False , 1 , 1 , port = disagg_port )
405401 # with only 1 ctx and 1 gen worker, the request should fail
406402 with pytest .raises (Exception ):
@@ -470,7 +466,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
470466 work_dir ,
471467 port = 0 ,
472468 device = 2 )
473- await wait_for_worker_ready ( gen_worker2 . port )
469+ await wait_for_disagg_server_status ( disagg_port , 1 , 1 )
474470 await asyncio .sleep (CHECK_STATUS_INTERVAL )
475471 verify_cluster_info (True , 1 , 1 , port = disagg_port )
476472
@@ -492,7 +488,8 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
492488 work_dir ,
493489 port = 0 ,
494490 device = 3 )
495- await wait_for_worker_ready (ctx_worker2 .port )
491+ await wait_for_disagg_server_status (disagg_port , 1 , 1 )
492+ await asyncio .sleep (CHECK_STATUS_INTERVAL )
496493 verify_cluster_info (True , 1 , 1 , port = disagg_port )
497494
498495 response = request_completion (model_name , test_prompt , port = disagg_port )
@@ -510,8 +507,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
510507 work_dir ,
511508 port = 0 ,
512509 device = 1 )
513- await wait_for_worker_ready (ctx_worker1 .port )
514- await wait_for_worker_ready (gen_worker1 .port )
510+ await wait_for_disagg_server_status (disagg_port , 2 , 2 )
515511 await asyncio .sleep (CHECK_STATUS_INTERVAL )
516512 verify_cluster_info (True , 2 , 2 , port = disagg_port )
517513
0 commit comments