diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index 6943df0c1ab..0dd86b66ac3 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -18,7 +18,7 @@
 from tensorrt_llm import LLM as PyTorchLLM
 from tensorrt_llm import MultimodalEncoder
 from tensorrt_llm._tensorrt_engine import LLM
-from tensorrt_llm._utils import mpi_rank
+from tensorrt_llm._utils import get_free_port, mpi_rank
 from tensorrt_llm.executor.utils import LlmLauncherEnvs
 from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
 from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
@@ -180,10 +180,27 @@ def launch_server(
     backend = llm_args["backend"]
     model = llm_args["model"]
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        try:
-            s.bind((host, port))
-        except OSError as e:
-            raise RuntimeError(f"Failed to bind socket to {host}:{port}: {e}")
+        # If disagg cluster config is provided and port is not specified, try to find a free port, otherwise try to bind to the specified port
+        assert port > 0 or disagg_cluster_config is not None, "Port must be specified if disagg cluster config is not provided"
+        if port > 0:
+            port_retries = 1
+        else:
+            port_retries = 100
+            port = get_free_port()
+        while port_retries > 0:
+            try:
+                s.bind((host, port))
+                break
+            except OSError as e:
+                port_retries -= 1
+                if port_retries == 0:
+                    raise RuntimeError(
+                        f"Failed to bind socket to {host}:{port}: {e}")
+                else:
+                    logger.warning(
+                        f"Failed to bind socket to {host}:{port}: {e}, retrying {port_retries}..."
+                    )
+                    port = get_free_port()
 
     if backend == 'pytorch':
         llm_args.pop("build_config", None)
diff --git a/tests/integration/defs/disaggregated/test_auto_scaling.py b/tests/integration/defs/disaggregated/test_auto_scaling.py
index 52838c8b20e..1dd37a17a19 100644
--- a/tests/integration/defs/disaggregated/test_auto_scaling.py
+++ b/tests/integration/defs/disaggregated/test_auto_scaling.py
@@ -23,24 +23,6 @@
 CHECK_STATUS_INTERVAL = 3
 ROUTER_TYPES = ["round_robin", "load_balancing", "kv_cache_aware"]
 
-USED_PORTS = set()
-
-
-# get_free_port doesn't guarantee that consecutive calls will return different ports
-# if no server is bound to the port immediately after the call
-def get_free_unused_port():
-    global USED_PORTS
-    max_attempts = 100
-    for _ in range(max_attempts):
-        port = get_free_port()
-        assert port > 0, f"get_free_port returned {port}"
-        if port not in USED_PORTS:
-            USED_PORTS.add(port)
-            return port
-        else:
-            logger.info(f"Port {port} is already used, trying another one")
-    raise Exception(
-        f"Failed to find a free unused port after {max_attempts} attempts")
 
 
 @pytest.fixture
@@ -53,7 +35,7 @@ def model_name():
 
 @pytest.fixture
 def disagg_port():
-    return get_free_unused_port()
+    return get_free_port()
 
 
 @pytest.fixture
@@ -145,8 +127,6 @@ def _run_worker(model_name,
                 work_dir,
                 device=-1,
                 save_log=False):
-    if port == 0:
-        port = get_free_unused_port()
     worker_config_path = os.path.join(work_dir, f"{role}_{port}_config.yaml")
     with open(worker_config_path, "w+") as f:
         yaml.dump(worker_config, f)
@@ -187,6 +167,7 @@ def _run_worker(model_name,
                       port=port)
 
 
+# Use 0 as the port and provide disagg_cluster_config to let the worker choose a free port
 def run_ctx_worker(model_name, ctx_worker_config, work_dir, port=0, device=0):
     return _run_worker(model_name, ctx_worker_config, "ctx", port, work_dir,
                        device)
@@ -246,19 +227,38 @@ async def wrapper(*args, **kwargs):
     return decorator
 
 
-@periodic_check(timeout=300, interval=3)
-async def wait_for_disagg_server_ready(port):
+async def _wait_for_disagg_server_status(port,
+                                         ready=True,
+                                         min_ctx_workers=-1,
+                                         min_gen_workers=-1):
     info_resp = requests.get(f"http://localhost:{port}/cluster_info")
     logger.info(
         f"Waiting for disagg server {port} to be ready: {info_resp.json()}")
     if info_resp.status_code == 200:
         info = info_resp.json()
-        return info["is_ready"]
-    else:
-        logger.info(f"Failed to get cluster info: {info_resp.status_code}")
+        if ready:
+            return info["is_ready"]
+        else:
+            return len(info["current_workers"]
+                       ["context_servers"]) >= min_ctx_workers and len(
+                           info["current_workers"]
+                           ["generation_servers"]) >= min_gen_workers
     return False
 
 
+@periodic_check(timeout=300, interval=3)
+async def wait_for_disagg_server_ready(port):
+    return await _wait_for_disagg_server_status(port, True)
+
+
+@periodic_check(timeout=300, interval=3)
+async def wait_for_disagg_server_status(port,
+                                        min_ctx_workers=-1,
+                                        min_gen_workers=-1):
+    return await _wait_for_disagg_server_status(port, False, min_ctx_workers,
+                                                min_gen_workers)
+
+
 @periodic_check(timeout=300, interval=3)
 async def wait_for_worker_ready(port):
     logger.info(f"Waiting for worker {port} to be ready")
@@ -314,9 +314,6 @@ def terminate(*args, show_log_lines=30, release_port=True):
             if arg.log_file:
                 arg.log_file.close()
                 arg.log_file = None
-            if release_port:
-                global USED_PORTS
-                USED_PORTS.discard(arg.port)
         except Exception:
             print(f"Failed to terminate process {arg.process.pid}")
         else:
@@ -399,8 +396,7 @@ async def test_minimal_instances(model_name, disagg_server_config,
     gen_worker1 = run_gen_worker(model_name, worker_config, work_dir)
     disagg_server = run_disagg_server(disagg_server_config, work_dir,
                                       disagg_port)
-    await wait_for_worker_ready(ctx_worker1.port)
-    await wait_for_worker_ready(gen_worker1.port)
+    await wait_for_disagg_server_status(disagg_port, 1, 1)
     verify_cluster_info(False, 1, 1, port=disagg_port)
     # with only 1 ctx and 1 gen worker, the request should fail
     with pytest.raises(Exception):
@@ -470,7 +466,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
                                  work_dir,
                                  port=0,
                                  device=2)
-    await wait_for_worker_ready(gen_worker2.port)
+    await wait_for_disagg_server_status(disagg_port, 1, 1)
     await asyncio.sleep(CHECK_STATUS_INTERVAL)
     verify_cluster_info(True, 1, 1, port=disagg_port)
 
@@ -492,7 +488,8 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
                                  work_dir,
                                  port=0,
                                  device=3)
-    await wait_for_worker_ready(ctx_worker2.port)
+    await wait_for_disagg_server_status(disagg_port, 1, 1)
+    await asyncio.sleep(CHECK_STATUS_INTERVAL)
     verify_cluster_info(True, 1, 1, port=disagg_port)
 
     response = request_completion(model_name, test_prompt, port=disagg_port)
@@ -510,8 +507,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
                                  work_dir,
                                  port=0,
                                  device=1)
-    await wait_for_worker_ready(ctx_worker1.port)
-    await wait_for_worker_ready(gen_worker1.port)
+    await wait_for_disagg_server_status(disagg_port, 2, 2)
     await asyncio.sleep(CHECK_STATUS_INTERVAL)
     verify_cluster_info(True, 2, 2, port=disagg_port)
 