Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions tensorrt_llm/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tensorrt_llm import LLM as PyTorchLLM
from tensorrt_llm import MultimodalEncoder
from tensorrt_llm._tensorrt_engine import LLM
from tensorrt_llm._utils import mpi_rank
from tensorrt_llm._utils import get_free_port, mpi_rank
from tensorrt_llm.executor.utils import LlmLauncherEnvs
from tensorrt_llm.inputs.multimodal import MultimodalServerConfig
from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy,
Expand Down Expand Up @@ -180,10 +180,27 @@ def launch_server(
backend = llm_args["backend"]
model = llm_args["model"]
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((host, port))
except OSError as e:
raise RuntimeError(f"Failed to bind socket to {host}:{port}: {e}")
# If disagg cluster config is provided and port is not specified, try to find a free port, otherwise try to bind to the specified port
assert port > 0 or disagg_cluster_config is not None, "Port must be specified if disagg cluster config is not provided"
if port > 0:
port_retries = 1
else:
port_retries = 100
port = get_free_port()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the disagg server need to know this arbitrary port?

while port_retries > 0:
try:
s.bind((host, port))
break
except OSError as e:
port_retries -= 1
if port_retries == 0:
raise RuntimeError(
f"Failed to bind socket to {host}:{port}: {e}")
else:
logger.warning(
f"Failed to bind socket to {host}:{port}: {e}, retrying {port_retries}..."
)
port = get_free_port()

if backend == 'pytorch':
llm_args.pop("build_config", None)
Expand Down
66 changes: 31 additions & 35 deletions tests/integration/defs/disaggregated/test_auto_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,6 @@
CHECK_STATUS_INTERVAL = 3

ROUTER_TYPES = ["round_robin", "load_balancing", "kv_cache_aware"]
USED_PORTS = set()


# get_free_port doesn't guarantee that consecutive calls will return different ports
# if no server is bound to the port immediately after the call
def get_free_unused_port():
    """Return a free port that this helper has not handed out before.

    Retries ``get_free_port`` until it yields a port not already tracked in
    the module-level ``USED_PORTS`` set, then records and returns it.

    Returns:
        int: a free TCP port number not previously returned by this helper.

    Raises:
        RuntimeError: if no unused port is found within the attempt limit.
    """
    # No `global` statement needed: we only mutate the set, never rebind it.
    max_attempts = 100
    for _ in range(max_attempts):
        port = get_free_port()
        assert port > 0, f"get_free_port returned {port}"
        if port not in USED_PORTS:
            USED_PORTS.add(port)
            return port
        logger.info(f"Port {port} is already used, trying another one")
    # Use a specific exception type instead of bare Exception; RuntimeError
    # is still caught by callers handling Exception.
    raise RuntimeError(
        f"Failed to find a free unused port after {max_attempts} attempts")


@pytest.fixture
Expand All @@ -53,7 +35,7 @@ def model_name():

@pytest.fixture
def disagg_port():
    """Provide a free TCP port for the disaggregated server under test."""
    # Diff residue removed: the stale deleted line duplicating the return
    # (old `return get_free_unused_port()`) is dropped; keep the new-side call.
    return get_free_port()


@pytest.fixture
Expand Down Expand Up @@ -145,8 +127,6 @@ def _run_worker(model_name,
work_dir,
device=-1,
save_log=False):
if port == 0:
port = get_free_unused_port()
worker_config_path = os.path.join(work_dir, f"{role}_{port}_config.yaml")
with open(worker_config_path, "w+") as f:
yaml.dump(worker_config, f)
Expand Down Expand Up @@ -187,6 +167,7 @@ def _run_worker(model_name,
port=port)


# Use 0 as the port and provide disagg_cluster_config to let the worker choose a free port
def run_ctx_worker(model_name, ctx_worker_config, work_dir, port=0, device=0):
return _run_worker(model_name, ctx_worker_config, "ctx", port, work_dir,
device)
Expand Down Expand Up @@ -246,19 +227,38 @@ async def wrapper(*args, **kwargs):
return decorator


async def _wait_for_disagg_server_status(port,
                                         ready=True,
                                         min_ctx_workers=-1,
                                         min_gen_workers=-1):
    """Query the disagg server's /cluster_info endpoint once.

    NOTE(review): this span contained interleaved deleted/added diff lines;
    reconstructed here as the new-side implementation.

    Args:
        port: local port the disagg server listens on.
        ready: when True, report the cluster's own ``is_ready`` flag; when
            False, check the registered worker counts instead.
        min_ctx_workers: minimum context workers required (used when
            ``ready`` is False).
        min_gen_workers: minimum generation workers required (used when
            ``ready`` is False).

    Returns:
        True when the requested condition holds, False otherwise
        (including on a non-200 response).
    """
    info_resp = requests.get(f"http://localhost:{port}/cluster_info")
    logger.info(
        f"Waiting for disagg server {port} to be ready: {info_resp.json()}")
    if info_resp.status_code == 200:
        info = info_resp.json()
        if ready:
            return info["is_ready"]
        return (len(info["current_workers"]["context_servers"])
                >= min_ctx_workers
                and len(info["current_workers"]["generation_servers"])
                >= min_gen_workers)
    logger.info(f"Failed to get cluster info: {info_resp.status_code}")
    return False


@periodic_check(timeout=300, interval=3)
async def wait_for_disagg_server_ready(port):
    """Poll until the disagg server on *port* reports it is ready."""
    is_ready = await _wait_for_disagg_server_status(port, ready=True)
    return is_ready


@periodic_check(timeout=300, interval=3)
async def wait_for_disagg_server_status(port,
                                        min_ctx_workers=-1,
                                        min_gen_workers=-1):
    """Poll until the cluster reports at least the given worker counts."""
    return await _wait_for_disagg_server_status(
        port,
        ready=False,
        min_ctx_workers=min_ctx_workers,
        min_gen_workers=min_gen_workers)


@periodic_check(timeout=300, interval=3)
async def wait_for_worker_ready(port):
logger.info(f"Waiting for worker {port} to be ready")
Expand Down Expand Up @@ -314,9 +314,6 @@ def terminate(*args, show_log_lines=30, release_port=True):
if arg.log_file:
arg.log_file.close()
arg.log_file = None
if release_port:
global USED_PORTS
USED_PORTS.discard(arg.port)
except Exception:
print(f"Failed to terminate process {arg.process.pid}")
else:
Expand Down Expand Up @@ -399,8 +396,7 @@ async def test_minimal_instances(model_name, disagg_server_config,
gen_worker1 = run_gen_worker(model_name, worker_config, work_dir)
disagg_server = run_disagg_server(disagg_server_config, work_dir,
disagg_port)
await wait_for_worker_ready(ctx_worker1.port)
await wait_for_worker_ready(gen_worker1.port)
await wait_for_disagg_server_status(disagg_port, 1, 1)
verify_cluster_info(False, 1, 1, port=disagg_port)
# with only 1 ctx and 1 gen worker, the request should fail
with pytest.raises(Exception):
Expand Down Expand Up @@ -470,7 +466,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
work_dir,
port=0,
device=2)
await wait_for_worker_ready(gen_worker2.port)
await wait_for_disagg_server_status(disagg_port, 1, 1)
await asyncio.sleep(CHECK_STATUS_INTERVAL)
verify_cluster_info(True, 1, 1, port=disagg_port)

Expand All @@ -492,7 +488,8 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
work_dir,
port=0,
device=3)
await wait_for_worker_ready(ctx_worker2.port)
await wait_for_disagg_server_status(disagg_port, 1, 1)
await asyncio.sleep(CHECK_STATUS_INTERVAL)
verify_cluster_info(True, 1, 1, port=disagg_port)

response = request_completion(model_name, test_prompt, port=disagg_port)
Expand All @@ -510,8 +507,7 @@ async def test_worker_restart(model_name, disagg_server_config, worker_config,
work_dir,
port=0,
device=1)
await wait_for_worker_ready(ctx_worker1.port)
await wait_for_worker_ready(gen_worker1.port)
await wait_for_disagg_server_status(disagg_port, 2, 2)
await asyncio.sleep(CHECK_STATUS_INTERVAL)
verify_cluster_info(True, 2, 2, port=disagg_port)

Expand Down
Loading