Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/multimodal/components/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def signal_handler():
args, config = VllmBaseWorker.parse_args()

# vLLM config overwrites
await configure_ports(runtime, config)
configure_ports(config)
overwrite_args(config)
await init(runtime, args, config)

Expand Down
116 changes: 28 additions & 88 deletions examples/multimodal/utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import json
import logging
import os
import socket
import sys
import time
from typing import Callable, List, Optional, Tuple

from vllm.config import KVTransferConfig
from vllm.distributed.kv_events import KVEventsConfig
from vllm.engine.arg_utils import AsyncEngineArgs

from dynamo.runtime import DistributedRuntime

logger = logging.getLogger(__name__)

DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
Expand All @@ -30,7 +26,6 @@ class Config:
component: str
endpoint: str
kv_port: Optional[int] = None
side_channel_port: Optional[int] = None

# mirror vLLM
model: str
Expand Down Expand Up @@ -115,76 +110,45 @@ def base_parse_args(
return args, config


async def allocate_and_reserve_port(
runtime: DistributedRuntime,
namespace: str,
worker_id: str,
reason: str,
) -> int:
"""
Get an OS-assigned port and atomically reserve it.
Retries until successful or internal max attempts reached.
"""
def get_kv_port() -> int:
"""Get KV events port from environment or default."""
return int(os.getenv("DYN_VLLM_KV_EVENT_PORT", "20080"))

context_json = {
"worker_id": worker_id,
"reason": reason,
"reserved_at": time.time(),
"pid": os.getpid(),
"block_size": 1,
}

# Any ephemeral port, equivalent to binding port 0
port_range_min = 32_768
port_range_max = 60_999
allocated_ports = await runtime.allocate_port_block(
namespace,
port_range_min,
port_range_max,
1, # how many ports to allocate
json.dumps(context_json),
)
if not allocated_ports:
raise RuntimeError("allocate_port_block returned no ports")
port = allocated_ports[0]
logger.debug(f"Reserved OS-assigned port {port} for {worker_id}")
return port
def ensure_side_channel_host():
"""Ensure the NIXL side-channel host is available without overriding user settings."""
existing_host = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST")
if existing_host:
logger.debug(
"Preserving existing VLLM_NIXL_SIDE_CHANNEL_HOST=%s", existing_host
)
return

try:
host_name = socket.gethostname()
host_ip = socket.gethostbyname(host_name)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket:
test_socket.bind((host_ip, 0))
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip
logger.debug("Set VLLM_NIXL_SIDE_CHANNEL_HOST to %s", host_ip)
except (socket.error, socket.gaierror):
logger.warning("Failed to get hostname, falling back to 127.0.0.1")
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = "127.0.0.1"

async def configure_ports(runtime: DistributedRuntime, config: Config):
"""Configure including port allocation and vLLM overrides."""

# First, allocate ports
dp_rank = config.engine_args.data_parallel_rank or 0
worker_id = f"vllm-{config.component}-dp{dp_rank}"

# Allocate KV events port
kv_port = await allocate_and_reserve_port(
runtime=runtime,
namespace=config.namespace,
worker_id=f"{worker_id}",
reason="zmq_kv_event_port",
)
def configure_ports(config: Config):
"""Configure port settings from dedicated environment overrides."""

# Allocate side channel port
side_channel_port = await allocate_and_reserve_port(
runtime=runtime,
namespace=config.namespace,
worker_id=f"{worker_id}",
reason="nixl_side_channel_port",
)
if config.engine_args.enable_prefix_caching:
config.kv_port = get_kv_port()

# Update config with allocated ports
config.kv_port = kv_port
config.side_channel_port = side_channel_port
ensure_side_channel_host()


def overwrite_args(config):
"""Set vLLM defaults for Dynamo."""
assert config.kv_port is not None, "Must set the kv_port, use configure_ports"
assert (
config.side_channel_port is not None
), "Must set the side_channel_port, use configure_ports"
if config.engine_args.enable_prefix_caching:
assert config.kv_port is not None, "Must set the kv_port, use configure_ports"

dp_rank = config.engine_args.data_parallel_rank or 0

Expand All @@ -206,34 +170,10 @@ def overwrite_args(config):
),
}

set_side_channel_host_and_port(config)

logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items():
if hasattr(config.engine_args, key):
setattr(config.engine_args, key, value)
logger.debug(f" engine_args.{key} = {value}")
else:
raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.")


def set_side_channel_host_and_port(config: Config, hostname: Optional[str] = None):
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.
This sets the port number for the side channel.
"""
if hostname is None:
hostname = socket.gethostname()
# Test if hostname is usable by attempting to bind to it
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as test_socket:
test_socket.bind((hostname, 0))
except (socket.error, socket.gaierror):
# If hostname is not usable, fall back to localhost
logger.warning(
f"Hostname '{hostname}' is not usable, falling back to '127.0.0.1'"
)
hostname = "127.0.0.1"

os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(config.side_channel_port)
logger.debug(f"Set NIXL side channel to {hostname}:{config.side_channel_port}")
Loading