Commit e519caf

fixed race condition in port acquisition
Signed-off-by: Eran Geva <[email protected]>
1 parent 0d2e271 commit e519caf

2 files changed: +139 −46 lines
tensorrt_llm/_torch/auto_deploy/distributed/common.py

Lines changed: 118 additions & 25 deletions
@@ -85,9 +85,21 @@ def get_rank_world_size() -> Tuple[int, int]:
     return get_rank(), get_world_size()


-def initialize_or_skip(*args, **kwargs) -> Tuple[int, int]:
+def initialize_or_skip(
+    rank: int = 0,
+    world_size: int = 1,
+    port: Optional[int] = None,
+    shared_port: Optional["mp.Value"] = None,
+    port_ready_barrier: Optional["mp.Barrier"] = None,
+) -> Tuple[int, int]:
     if not dist.is_initialized():
-        return initialize(*args, **kwargs)
+        return initialize(
+            rank=rank,
+            world_size=world_size,
+            port=port,
+            shared_port=shared_port,
+            port_ready_barrier=port_ready_barrier,
+        )
     return get_rank(), get_world_size()


@@ -112,7 +124,48 @@ def cleanup():
         dist.destroy_process_group()


-def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -> Tuple[int, int]:
+def _try_init_process_group(local_rank: int, world_size: int, port: int) -> bool:
+    """Attempt to initialize process group. Returns True on success, False on EADDRINUSE."""
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+    os.environ["LOCAL_RANK"] = str(local_rank)
+
+    try:
+        dist.init_process_group(
+            "nccl",
+            world_size=world_size,
+            rank=local_rank,
+            device_id=torch.device(local_rank),
+        )
+        return True
+    except Exception as e:
+        # Check if this is a port-in-use error (only rank 0 binds, so only rank 0 can get this)
+        if "EADDRINUSE" in str(e) or "address already in use" in str(e).lower():
+            ad_logger.warning(f"Port {port} already in use, will retry with new port")
+            return False
+        raise
+
+
+def initialize(
+    rank: int = 0,
+    world_size: int = 1,
+    port: Optional[int] = None,
+    shared_port: Optional["mp.Value"] = None,
+    port_ready_barrier: Optional["mp.Barrier"] = None,
+    max_retries: int = 5,
+) -> Tuple[int, int]:
+    """Initialize distributed process group.
+
+    Args:
+        rank: Process rank (ignored for OMPI/torchelastic).
+        world_size: Total number of processes (ignored for OMPI/torchelastic).
+        port: Initial port to try. If None, a free port will be selected.
+        shared_port: Optional mp.Value for rank 0 to share the final port with other ranks.
+        port_ready_barrier: Optional mp.Barrier to synchronize port selection.
+        max_retries: Maximum number of port retry attempts for rank 0.
+    """
     if is_ompi():
         lib = "OMPI"
         local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
@@ -131,25 +184,53 @@ def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -> Tuple[int, int]:
         port = get_free_port()

     ad_logger.set_rank(local_rank)
-    ad_logger.info(f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=}")
-
-    # Set up environment variable to run with mpirun
-    os.environ["RANK"] = str(local_rank)
-    os.environ["WORLD_SIZE"] = str(world_size)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(port)
-    os.environ["LOCAL_RANK"] = str(local_rank)

     # Necessary to assign a device to each rank.
     torch.cuda.set_device(local_rank)

-    # We use nccl backend
-    dist.init_process_group(
-        "nccl",
-        world_size=world_size,
-        rank=local_rank,
-        device_id=torch.device(local_rank),
-    )
+    # If we have shared port synchronization (multiprocess spawn mode)
+    if shared_port is not None and port_ready_barrier is not None:
+        if local_rank == 0:
+            # Rank 0: try ports until one works, then share with other ranks
+            for attempt in range(max_retries):
+                ad_logger.info(
+                    f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=} (attempt {attempt + 1})"
+                )
+                if _try_init_process_group(local_rank, world_size, port):
+                    # Success! Share the working port with other ranks
+                    shared_port.value = port
+                    port_ready_barrier.wait()  # Signal other ranks
+                    break
+                else:
+                    # Port was taken, try a new one
+                    port = get_free_port()
+            else:
+                # All retries exhausted
+                shared_port.value = -1  # Signal failure
+                port_ready_barrier.wait()
+                raise RuntimeError(f"Failed to find available port after {max_retries} attempts")
+        else:
+            # Other ranks: wait for rank 0 to find a working port
+            port_ready_barrier.wait()
+            port = shared_port.value
+            if port == -1:
+                raise RuntimeError("Rank 0 failed to initialize, cannot proceed")
+            ad_logger.info(f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=}")
+            dist.init_process_group(
+                "nccl",
+                world_size=world_size,
+                rank=local_rank,
+                device_id=torch.device(local_rank),
+            )
+    else:
+        # Original path: no retry mechanism (OMPI, torchelastic, or single process)
+        ad_logger.info(f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=}")
+        dist.init_process_group(
+            "nccl",
+            world_size=world_size,
+            rank=local_rank,
+            device_id=torch.device(local_rank),
+        )

     # Register cleanup function to be called at exit
     atexit.register(cleanup)
@@ -160,9 +241,13 @@ def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -> Tuple[int, int]:
     return local_rank, world_size


-def init_and_run_process(job, rank, size, port, **kwargs):
+def init_and_run_process(
+    job, rank, size, port, shared_port=None, port_ready_barrier=None, **kwargs
+):
     try:
-        initialize_or_skip(rank, size, port)
+        initialize_or_skip(
+            rank, size, port, shared_port=shared_port, port_ready_barrier=port_ready_barrier
+        )
         job(rank, size, **kwargs)
     except Exception as e:
         # Close the input and output queues to parent process can exit.
@@ -212,19 +297,27 @@ def _start_multiprocess_job(
         init_and_run_process(job, 0, 1, port, **kwargs)
         return None

-    mp.set_start_method("spawn", force=True)
+    # Use explicit spawn context to ensure synchronization primitives work correctly
+    ctx = mp.get_context("spawn")
     processes: List[mp.Process] = []

+    # Create shared state for port synchronization with retry mechanism:
+    # - shared_port: rank 0 writes the final working port here
+    # - port_ready_barrier: all ranks wait here until rank 0 has bound successfully
+    shared_port = ctx.Value("i", port)  # 'i' = signed int
+    port_ready_barrier = ctx.Barrier(size)
+
     for rank in range(size):
         if input_queues:
             kwargs["input_queue"] = input_queues[rank]
         if output_queue:
             kwargs["output_queue"] = output_queue if rank == 0 else None

-        # Use thread for the single worker case.
-        launch_method = mp.Process
-        p = launch_method(
-            target=init_and_run_process, args=(job, rank, size, port), kwargs=kwargs, daemon=True
+        p = ctx.Process(
+            target=init_and_run_process,
+            args=(job, rank, size, port),
+            kwargs={**kwargs, "shared_port": shared_port, "port_ready_barrier": port_ready_barrier},
+            daemon=True,
         )
         p.start()
         processes.append(p)
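
The change above is a producer/consumer handshake: rank 0 keeps binding until it finds a usable port, publishes the winning port through a shared multiprocessing Value, and a Barrier keeps the other ranks from reading that value before rank 0 has actually bound. A minimal standalone sketch of the same pattern is shown below; it is an illustration only, with plain sockets standing in for dist.init_process_group and made-up helper names (worker), not code from this commit.

# Standalone sketch of the shared-port handshake (illustration only; plain
# sockets stand in for dist.init_process_group, helper names are hypothetical).
import multiprocessing as mp
import socket


def get_free_port() -> int:
    # Ask the OS for an ephemeral port; another process may still grab it
    # before we rebind, which is exactly the race the retry loop handles.
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def worker(rank, shared_port, port_ready_barrier, max_retries=5):
    if rank == 0:
        port = get_free_port()
        for _ in range(max_retries):
            try:
                server = socket.socket()
                server.bind(("127.0.0.1", port))  # stand-in for rank 0 creating the store
                break
            except OSError:
                server.close()
                port = get_free_port()  # port got taken in the meantime, pick a new one
        else:
            port = -1  # all retries exhausted, signal failure
        shared_port.value = port
        port_ready_barrier.wait()  # release the other ranks only after the outcome is known
    else:
        port_ready_barrier.wait()  # block until rank 0 has published a port (or -1)
        port = shared_port.value
    if port == -1:
        raise RuntimeError(f"rank {rank}: no usable port found")
    print(f"rank {rank} agreed on port {port}")


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # primitives must come from the same start-method context
    shared_port = ctx.Value("i", 0)
    barrier = ctx.Barrier(2)
    procs = [ctx.Process(target=worker, args=(r, shared_port, barrier)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

Creating the Value and Barrier from the same spawn context that creates the processes (rather than calling mp.set_start_method globally) mirrors the switch to ctx.Process in the diff above.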

tensorrt_llm/_torch/custom_ops/cpp_custom_ops.py

Lines changed: 21 additions & 21 deletions
@@ -752,27 +752,27 @@ def _(input_list, group, num_lists):
             for i in range(0, len(input_list), num_ranks)
         ]

-    @torch.library.register_fake("trtllm::alltoall_helix_native")
-    def _(partial_o, softmax_stats, workspace, cp_rank, cp_size):
-        # Returns outputs with same shapes as inputs
-        return partial_o.new_empty(partial_o.shape), softmax_stats.new_empty(
-            softmax_stats.shape)
-
-    @torch.library.register_fake("trtllm::initialize_helix_workspace")
-    def _(workspace, cp_rank, cp_size):
-        # This op initializes workspace in-place and returns nothing
-        return None
-
-    @torch.library.register_fake("trtllm::helix_post_process")
-    def _(gathered_o, gathered_stats, scale):
-        return gathered_o.new_empty(*gathered_o.shape[1:])
-
-    @torch.library.register_fake("trtllm::helix_post_process_native")
-    def _(gathered_o, gathered_stats, scale, cp_dim):
-        # Remove the dimension at cp_dim (context parallelism dimension)
-        out_shape = list(gathered_o.shape)
-        del out_shape[cp_dim]
-        return gathered_o.new_empty(*out_shape)
+    # @torch.library.register_fake("trtllm::alltoall_helix_native")
+    # def _(partial_o, softmax_stats, workspace, cp_rank, cp_size):
+    #     # Returns outputs with same shapes as inputs
+    #     return partial_o.new_empty(partial_o.shape), softmax_stats.new_empty(
+    #         softmax_stats.shape)
+
+    # @torch.library.register_fake("trtllm::initialize_helix_workspace")
+    # def _(workspace, cp_rank, cp_size):
+    #     # This op initializes workspace in-place and returns nothing
+    #     return None
+
+    # @torch.library.register_fake("trtllm::helix_post_process")
+    # def _(gathered_o, gathered_stats, scale):
+    #     return gathered_o.new_empty(*gathered_o.shape[1:])
+
+    # @torch.library.register_fake("trtllm::helix_post_process_native")
+    # def _(gathered_o, gathered_stats, scale, cp_dim):
+    #     # Remove the dimension at cp_dim (context parallelism dimension)
+    #     out_shape = list(gathered_o.shape)
+    #     del out_shape[cp_dim]
+    #     return gathered_o.new_empty(*out_shape)

     @torch.library.register_fake("trtllm::tinygemm2")
     def _(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
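
For reference, torch.library.register_fake attaches a shape-and-dtype-only "fake" (meta) implementation to a custom op so tracing machinery such as torch.compile can propagate metadata without running the real kernel; commenting out the registrations above removes that metadata for the helix ops. Below is a minimal illustration of the mechanism on a made-up op, mylib::scale_rows, assuming PyTorch 2.4+ where torch.library.custom_op and register_fake are available; it is not one of the trtllm ops touched in this commit.

# Illustration of register_fake on a hypothetical op, unrelated to this commit.
import torch


@torch.library.custom_op("mylib::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real implementation: runs on actual tensor data.
    return x * factor


@torch.library.register_fake("mylib::scale_rows")
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake implementation: only describes output shape/dtype/device for the
    # tracer; no kernel is executed.
    return x.new_empty(x.shape)


if __name__ == "__main__":
    y = torch.ops.mylib.scale_rows(torch.ones(2, 3), 2.0)
    print(y.shape)  # torch.Size([2, 3])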
