tensorrt_llm/_torch/auto_deploy/distributed/common.py (164 changes: 139 additions & 25 deletions)
@@ -85,9 +85,21 @@ def get_rank_world_size() -> Tuple[int, int]:
return get_rank(), get_world_size()


def initialize_or_skip(*args, **kwargs) -> Tuple[int, int]:
def initialize_or_skip(
rank: int = 0,
world_size: int = 1,
port: Optional[int] = None,
shared_port: Optional["mp.Value"] = None,
port_ready_barrier: Optional["mp.Barrier"] = None,
) -> Tuple[int, int]:
if not dist.is_initialized():
return initialize(*args, **kwargs)
return initialize(
rank=rank,
world_size=world_size,
port=port,
shared_port=shared_port,
port_ready_barrier=port_ready_barrier,
)
return get_rank(), get_world_size()


@@ -112,7 +124,53 @@ def cleanup():
dist.destroy_process_group()


def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -> Tuple[int, int]:
def _set_distributed_env_vars(local_rank: int, world_size: int, port: int) -> None:
"""Set environment variables required by NCCL's env:// init method."""
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["LOCAL_RANK"] = str(local_rank)


def _try_init_process_group(local_rank: int, world_size: int, port: int) -> bool:
"""Attempt to initialize process group. Returns True on success, False on EADDRINUSE."""
_set_distributed_env_vars(local_rank, world_size, port)

try:
dist.init_process_group(
"nccl",
world_size=world_size,
rank=local_rank,
device_id=torch.device(local_rank),
)
return True
except Exception as e:
# Check if this is a port-in-use error (only rank 0 binds, so only rank 0 can get this)
if "EADDRINUSE" in str(e) or "address already in use" in str(e).lower():
ad_logger.warning(f"Port {port} already in use, will retry with new port")
return False
raise
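Together with get_free_port(), this helper supports the bind-and-retry loop that rank 0 runs further below. Condensed into a hypothetical standalone helper (a sketch of the pattern, not part of the module):

def _bind_with_retry(local_rank: int, world_size: int, port: int, max_retries: int = 5) -> int:
    """Hypothetical condensed form of rank 0's loop below; returns the port that bound."""
    for _ in range(max_retries):
        if _try_init_process_group(local_rank, world_size, port):
            return port            # bound successfully; this is the port to publish
        port = get_free_port()     # EADDRINUSE: pick another candidate and retry
    raise RuntimeError(f"Failed to find an available port after {max_retries} attempts")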


def initialize(
rank: int = 0,
world_size: int = 1,
port: Optional[int] = None,
shared_port: Optional["mp.Value"] = None,
port_ready_barrier: Optional["mp.Barrier"] = None,
max_retries: int = 5,
) -> Tuple[int, int]:
"""Initialize distributed process group.

Args:
rank: Process rank (ignored for OMPI/torchelastic).
world_size: Total number of processes (ignored for OMPI/torchelastic).
port: Initial port to try. If None, a free port will be selected.
shared_port: Optional mp.Value for rank 0 to share the final port with other ranks.
port_ready_barrier: Optional mp.Barrier to synchronize port selection.
max_retries: Maximum number of port retry attempts for rank 0.
"""
if is_ompi():
lib = "OMPI"
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
@@ -131,25 +189,69 @@ def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -
port = get_free_port()

ad_logger.set_rank(local_rank)
ad_logger.info(f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=}")

# Set up environment variable to run with mpirun
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["LOCAL_RANK"] = str(local_rank)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)

# We use nccl backend
dist.init_process_group(
"nccl",
world_size=world_size,
rank=local_rank,
device_id=torch.device(local_rank),
)
# If we have shared port synchronization (multiprocess spawn mode)
if shared_port is not None and port_ready_barrier is not None:
if local_rank == 0:
# Rank 0: try ports until one works, then share with other ranks
init_success = False
init_error = None
try:
for attempt in range(max_retries):
ad_logger.info(
f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=} (attempt {attempt + 1})"
)
if _try_init_process_group(local_rank, world_size, port):
# Success! Share the working port with other ranks
shared_port.value = port
init_success = True
break
else:
# Port was taken, try a new one
port = get_free_port()
else:
# All retries exhausted
init_error = RuntimeError(
f"Failed to find available port after {max_retries} attempts"
)
except Exception as e:
# Catch any unexpected error so we can still signal other ranks
init_error = e
finally:
# ALWAYS signal other ranks, even on error, to prevent deadlock
if not init_success:
shared_port.value = -1
port_ready_barrier.wait()

if init_error is not None:
raise init_error
else:
# Other ranks: wait for rank 0 to find a working port
port_ready_barrier.wait()
port = shared_port.value
if port == -1:
raise RuntimeError("Rank 0 failed to initialize, cannot proceed")
ad_logger.info(f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=}")
_set_distributed_env_vars(local_rank, world_size, port)
dist.init_process_group(
"nccl",
world_size=world_size,
rank=local_rank,
device_id=torch.device(local_rank),
)
else:
# Original path: no retry mechanism (OMPI, torchelastic, or single process)
ad_logger.info(f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=}")
_set_distributed_env_vars(local_rank, world_size, port)
dist.init_process_group(
"nccl",
world_size=world_size,
rank=local_rank,
device_id=torch.device(local_rank),
)

# Register cleanup function to be called at exit
atexit.register(cleanup)
@@ -160,9 +262,13 @@ def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -
return local_rank, world_size
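The rank-0/other-rank handshake above boils down to a shared integer plus a barrier: rank 0 publishes either the working port or -1, and every rank releases from the barrier before reading it. A self-contained sketch of that contract, with a dummy port value standing in for the real NCCL bind (illustrative only, not part of the change):

import multiprocessing as mp

def _worker(rank, shared_port, barrier):
    if rank == 0:
        shared_port.value = 29512      # stand-in for "rank 0 bound this port"
    barrier.wait()                     # every rank blocks until rank 0 has published
    if shared_port.value == -1:
        raise RuntimeError("rank 0 failed to initialize")
    print(f"rank {rank} connects on port {shared_port.value}")

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    shared_port = ctx.Value("i", 0)
    barrier = ctx.Barrier(2)
    workers = [ctx.Process(target=_worker, args=(r, shared_port, barrier)) for r in range(2)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()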


def init_and_run_process(job, rank, size, port, **kwargs):
def init_and_run_process(
job, rank, size, port, shared_port=None, port_ready_barrier=None, **kwargs
):
try:
initialize_or_skip(rank, size, port)
initialize_or_skip(
rank, size, port, shared_port=shared_port, port_ready_barrier=port_ready_barrier
)
job(rank, size, **kwargs)
except Exception as e:
# Close the input and output queues so the parent process can exit.
@@ -212,19 +318,27 @@ def _start_multiprocess_job(
init_and_run_process(job, 0, 1, port, **kwargs)
return None

mp.set_start_method("spawn", force=True)
# Use explicit spawn context to ensure synchronization primitives work correctly
ctx = mp.get_context("spawn")
processes: List[mp.Process] = []

# Create shared state for port synchronization with retry mechanism:
# - shared_port: rank 0 writes the final working port here
# - port_ready_barrier: all ranks wait here until rank 0 has bound successfully
shared_port = ctx.Value("i", port) # 'i' = signed int
port_ready_barrier = ctx.Barrier(size)

for rank in range(size):
if input_queues:
kwargs["input_queue"] = input_queues[rank]
if output_queue:
kwargs["output_queue"] = output_queue if rank == 0 else None

# Use thread for the single worker case.
launch_method = mp.Process
p = launch_method(
target=init_and_run_process, args=(job, rank, size, port), kwargs=kwargs, daemon=True
p = ctx.Process(
target=init_and_run_process,
args=(job, rank, size, port),
kwargs={**kwargs, "shared_port": shared_port, "port_ready_barrier": port_ready_barrier},
daemon=True,
)
p.start()
processes.append(p)
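One more point on the explicit spawn context: the multiprocessing docs note that objects tied to one context may not be compatible with processes from another, which is why the Value, the Barrier, and the Process objects above all come from the same ctx. A minimal standalone illustration of that rule (hypothetical, with the parent acting as the second barrier party):

import multiprocessing as mp

def _wait_on(barrier):
    barrier.wait()

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    barrier = ctx.Barrier(2)                            # created from the spawn context...
    p = ctx.Process(target=_wait_on, args=(barrier,))   # ...and handed only to spawn-context processes
    p.start()
    barrier.wait()                                      # parent is the second barrier party
    p.join()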