Skip to content

Commit b085bfb

Browse files
committed
Fixed race condition
Signed-off-by: Eran Geva <[email protected]>
1 parent 655d78a commit b085bfb

File tree

1 file changed

+29
-15
lines changed
  • tensorrt_llm/_torch/auto_deploy/distributed

1 file changed

+29
-15
lines changed

tensorrt_llm/_torch/auto_deploy/distributed/common.py

Lines changed: 29 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -197,23 +197,37 @@ def initialize(
197197
if shared_port is not None and port_ready_barrier is not None:
198198
if local_rank == 0:
199199
# Rank 0: try ports until one works, then share with other ranks
200-
for attempt in range(max_retries):
201-
ad_logger.info(
202-
f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=} (attempt {attempt + 1})"
203-
)
204-
if _try_init_process_group(local_rank, world_size, port):
205-
# Success! Share the working port with other ranks
206-
shared_port.value = port
207-
port_ready_barrier.wait() # Signal other ranks
208-
break
200+
init_success = False
201+
init_error = None
202+
try:
203+
for attempt in range(max_retries):
204+
ad_logger.info(
205+
f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=} (attempt {attempt + 1})"
206+
)
207+
if _try_init_process_group(local_rank, world_size, port):
208+
# Success! Share the working port with other ranks
209+
shared_port.value = port
210+
init_success = True
211+
break
212+
else:
213+
# Port was taken, try a new one
214+
port = get_free_port()
209215
else:
210-
# Port was taken, try a new one
211-
port = get_free_port()
212-
else:
213-
# All retries exhausted
214-
shared_port.value = -1 # Signal failure
216+
# All retries exhausted
217+
init_error = RuntimeError(
218+
f"Failed to find available port after {max_retries} attempts"
219+
)
220+
except Exception as e:
221+
# Catch any unexpected error so we can still signal other ranks
222+
init_error = e
223+
finally:
224+
# ALWAYS signal other ranks, even on error, to prevent deadlock
225+
if not init_success:
226+
shared_port.value = -1
215227
port_ready_barrier.wait()
216-
raise RuntimeError(f"Failed to find available port after {max_retries} attempts")
228+
229+
if init_error is not None:
230+
raise init_error
217231
else:
218232
# Other ranks: wait for rank 0 to find a working port
219233
port_ready_barrier.wait()

0 commit comments

Comments
 (0)