1212import pytest
1313import requests
1414import yaml
15+ from defs .common import get_free_port_in_ci as get_free_port
1516from defs .conftest import llm_models_root
1617
17- from tensorrt_llm ._utils import get_free_port
1818from tensorrt_llm .logger import logger
1919
2020HEARTBEAT_INTERVAL = 1
@@ -454,7 +454,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
454454 port = disagg_port )
455455 print (response )
456456 # kill gen1, the request should fail
457- terminate (gen_worker1 , release_port = False )
457+ terminate (gen_worker1 , release_port = True )
458458 await asyncio .sleep (CHECK_STATUS_INTERVAL )
459459 verify_cluster_info (False , 1 , 0 , port = disagg_port )
460460 with pytest .raises (Exception ):
@@ -480,7 +480,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
480480 assert len (response .choices [0 ].text ) >= 1
481481
482482 # kill ctx1, the request should fail
483- terminate (ctx_worker1 , release_port = False )
483+ terminate (ctx_worker1 , release_port = True )
484484 await asyncio .sleep (CHECK_STATUS_INTERVAL )
485485 verify_cluster_info (False , 0 , 1 , port = disagg_port )
486486 with pytest .raises (Exception ):
@@ -500,16 +500,16 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
500500 assert len (response .choices [0 ].text ) >= 1
501501
502502 # start ctx1 and gen1 again, we have 2 ctxs and 2 gens now
503- await wait_for_port_released (ctx_worker1 .port )
504- await wait_for_port_released (gen_worker1 .port )
505503 ctx_worker1 = run_ctx_worker (model_name ,
506504 worker_config ,
507505 work_dir ,
508- port = ctx_worker1 .port )
506+ port = 0 ,
507+ device = 0 )
509508 gen_worker1 = run_gen_worker (model_name ,
510509 worker_config ,
511510 work_dir ,
512- port = gen_worker1 .port )
511+ port = 0 ,
512+ device = 1 )
513513 await wait_for_worker_ready (ctx_worker1 .port )
514514 await wait_for_worker_ready (gen_worker1 .port )
515515 await asyncio .sleep (CHECK_STATUS_INTERVAL )
@@ -556,6 +556,7 @@ async def test_disagg_server_restart(model_name, disagg_server_config,
556556 terminate (disagg_server )
557557 # wait for the port to be released, so we can rebind the new process to the same port
558558 await wait_for_port_released (disagg_port )
559+ await asyncio .sleep (CHECK_STATUS_INTERVAL )
559560
560561 with pytest .raises (requests .exceptions .RequestException ):
561562 verify_cluster_info (False ,
0 commit comments