Skip to content

Commit e3bd35d

Browse files
committed
allow using port 0 when disagg service discovery is enabled
Signed-off-by: Lizhi Zhou <[email protected]>
1 parent 4a1b742 commit e3bd35d

File tree

2 files changed

+53
-40
lines changed

2 files changed

+53
-40
lines changed

tensorrt_llm/commands/serve.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from tensorrt_llm import LLM as PyTorchLLM
1919
from tensorrt_llm import MultimodalEncoder
2020
from tensorrt_llm._tensorrt_engine import LLM
21-
from tensorrt_llm._utils import mpi_rank
21+
from tensorrt_llm._utils import get_free_port, mpi_rank
2222
from tensorrt_llm.executor.utils import LlmLauncherEnvs
2323
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
2424
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
@@ -180,10 +180,27 @@ def launch_server(
180180
backend = llm_args["backend"]
181181
model = llm_args["model"]
182182
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
183-
try:
184-
s.bind((host, port))
185-
except OSError as e:
186-
raise RuntimeError(f"Failed to bind socket to {host}:{port}: {e}")
183+
# If disagg cluster config is provided and port is not specified, try to find a free port, otherwise try to bind to the specified port
184+
assert port > 0 or disagg_cluster_config is not None, "Port must be specified if disagg cluster config is not provided"
185+
if port > 0:
186+
port_retries = 1
187+
else:
188+
port_retries = 100
189+
port = get_free_port()
190+
while port_retries > 0:
191+
try:
192+
s.bind((host, port))
193+
break
194+
except OSError as e:
195+
port_retries -= 1
196+
if port_retries == 0:
197+
raise RuntimeError(
198+
f"Failed to bind socket to {host}:{port}: {e}")
199+
else:
200+
logger.warning(
201+
f"Failed to bind socket to {host}:{port}: {e}, retrying {port_retries}..."
202+
)
203+
port = get_free_port()
187204

188205
if backend == 'pytorch':
189206
llm_args.pop("build_config", None)

tests/integration/defs/disaggregated/test_auto_scaling.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,6 @@
2323
CHECK_STATUS_INTERVAL = 3
2424

2525
ROUTER_TYPES = ["round_robin", "load_balancing", "kv_cache_aware"]
26-
USED_PORTS = set()
27-
28-
29-
# get_free_port doesn't guarantee that consecutive calls will return different ports
30-
# if no server is bound to the port immediately after the call
31-
def get_free_unused_port():
32-
global USED_PORTS
33-
max_attempts = 100
34-
for _ in range(max_attempts):
35-
port = get_free_port()
36-
assert port > 0, f"get_free_port returned {port}"
37-
if port not in USED_PORTS:
38-
USED_PORTS.add(port)
39-
return port
40-
else:
41-
logger.info(f"Port {port} is already used, trying another one")
42-
raise Exception(
43-
f"Failed to find a free unused port after {max_attempts} attempts")
4426

4527

4628
@pytest.fixture
@@ -53,7 +35,7 @@ def model_name():
5335

5436
@pytest.fixture
5537
def disagg_port():
56-
return get_free_unused_port()
38+
return get_free_port()
5739

5840

5941
@pytest.fixture
@@ -145,8 +127,6 @@ def _run_worker(model_name,
145127
work_dir,
146128
device=-1,
147129
save_log=False):
148-
if port == 0:
149-
port = get_free_unused_port()
150130
worker_config_path = os.path.join(work_dir, f"{role}_{port}_config.yaml")
151131
with open(worker_config_path, "w+") as f:
152132
yaml.dump(worker_config, f)
@@ -187,6 +167,7 @@ def _run_worker(model_name,
187167
port=port)
188168

189169

170+
# Use 0 as the port and provide disagg_cluster_config to let the worker choose a free port
190171
def run_ctx_worker(model_name, ctx_worker_config, work_dir, port=0, device=0):
191172
return _run_worker(model_name, ctx_worker_config, "ctx", port, work_dir,
192173
device)
@@ -246,19 +227,38 @@ async def wrapper(*args, **kwargs):
246227
return decorator
247228

248229

249-
@periodic_check(timeout=300, interval=3)
250-
async def wait_for_disagg_server_ready(port):
230+
async def _wait_for_disagg_server_status(port,
231+
ready=True,
232+
min_ctx_workers=-1,
233+
min_gen_workers=-1):
251234
info_resp = requests.get(f"http://localhost:{port}/cluster_info")
252235
logger.info(
253236
f"Waiting for disagg server {port} to be ready: {info_resp.json()}")
254237
if info_resp.status_code == 200:
255238
info = info_resp.json()
256-
return info["is_ready"]
257-
else:
258-
logger.info(f"Failed to get cluster info: {info_resp.status_code}")
239+
if ready:
240+
return info["is_ready"]
241+
else:
242+
return len(info["current_workers"]
243+
["context_servers"]) >= min_ctx_workers and len(
244+
info["current_workers"]
245+
["generation_servers"]) >= min_gen_workers
259246
return False
260247

261248

249+
@periodic_check(timeout=300, interval=3)
250+
async def wait_for_disagg_server_ready(port):
251+
return await _wait_for_disagg_server_status(port, True)
252+
253+
254+
@periodic_check(timeout=300, interval=3)
255+
async def wait_for_disagg_server_status(port,
256+
min_ctx_workers=-1,
257+
min_gen_workers=-1):
258+
return await _wait_for_disagg_server_status(port, False, min_ctx_workers,
259+
min_gen_workers)
260+
261+
262262
@periodic_check(timeout=300, interval=3)
263263
async def wait_for_worker_ready(port):
264264
logger.info(f"Waiting for worker {port} to be ready")
@@ -314,9 +314,6 @@ def terminate(*args, show_log_lines=30, release_port=True):
314314
if arg.log_file:
315315
arg.log_file.close()
316316
arg.log_file = None
317-
if release_port:
318-
global USED_PORTS
319-
USED_PORTS.discard(arg.port)
320317
except Exception:
321318
print(f"Failed to terminate process {arg.process.pid}")
322319
else:
@@ -399,8 +396,7 @@ async def test_minimal_instances(model_name, disagg_server_config,
399396
gen_worker1 = run_gen_worker(model_name, worker_config, work_dir)
400397
disagg_server = run_disagg_server(disagg_server_config, work_dir,
401398
disagg_port)
402-
await wait_for_worker_ready(ctx_worker1.port)
403-
await wait_for_worker_ready(gen_worker1.port)
399+
await wait_for_disagg_server_status(disagg_port, 1, 1)
404400
verify_cluster_info(False, 1, 1, port=disagg_port)
405401
# with only 1 ctx and 1 gen worker, the request should fail
406402
with pytest.raises(Exception):
@@ -470,7 +466,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
470466
work_dir,
471467
port=0,
472468
device=2)
473-
await wait_for_worker_ready(gen_worker2.port)
469+
await wait_for_disagg_server_status(disagg_port, 1, 1)
474470
await asyncio.sleep(CHECK_STATUS_INTERVAL)
475471
verify_cluster_info(True, 1, 1, port=disagg_port)
476472

@@ -492,7 +488,8 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
492488
work_dir,
493489
port=0,
494490
device=3)
495-
await wait_for_worker_ready(ctx_worker2.port)
491+
await wait_for_disagg_server_status(disagg_port, 1, 1)
492+
await asyncio.sleep(CHECK_STATUS_INTERVAL)
496493
verify_cluster_info(True, 1, 1, port=disagg_port)
497494

498495
response = request_completion(model_name, test_prompt, port=disagg_port)
@@ -510,8 +507,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
510507
work_dir,
511508
port=0,
512509
device=1)
513-
await wait_for_worker_ready(ctx_worker1.port)
514-
await wait_for_worker_ready(gen_worker1.port)
510+
await wait_for_disagg_server_status(disagg_port, 2, 2)
515511
await asyncio.sleep(CHECK_STATUS_INTERVAL)
516512
verify_cluster_info(True, 2, 2, port=disagg_port)
517513

0 commit comments

Comments (0)