@@ -85,9 +85,21 @@ def get_rank_world_size() -> Tuple[int, int]:
     return get_rank(), get_world_size()


-def initialize_or_skip(*args, **kwargs) -> Tuple[int, int]:
+def initialize_or_skip(
+    rank: int = 0,
+    world_size: int = 1,
+    port: Optional[int] = None,
+    shared_port: Optional["mp.Value"] = None,
+    port_ready_barrier: Optional["mp.Barrier"] = None,
+) -> Tuple[int, int]:
     if not dist.is_initialized():
-        return initialize(*args, **kwargs)
+        return initialize(
+            rank=rank,
+            world_size=world_size,
+            port=port,
+            shared_port=shared_port,
+            port_ready_barrier=port_ready_barrier,
+        )
     return get_rank(), get_world_size()

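For orientation, a hypothetical call site under the new explicit signature (local_rank and num_gpus are placeholder names, not part of this change):

# Hypothetical caller sketch; initialize_or_skip() either creates the process group
# or, if one already exists, just returns the current (rank, world_size).
rank, world_size = initialize_or_skip(
    rank=local_rank,
    world_size=num_gpus,
    port=None,               # None lets initialize() pick a free port
    shared_port=None,        # only wired up in the multiprocess spawn path below
    port_ready_barrier=None,
)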
@@ -112,7 +124,48 @@ def cleanup():
         dist.destroy_process_group()


-def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -> Tuple[int, int]:
+def _try_init_process_group(local_rank: int, world_size: int, port: int) -> bool:
+    """Attempt to initialize the process group. Returns True on success, False on EADDRINUSE."""
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+    os.environ["LOCAL_RANK"] = str(local_rank)
+
+    try:
+        dist.init_process_group(
+            "nccl",
+            world_size=world_size,
+            rank=local_rank,
+            device_id=torch.device(local_rank),
+        )
+        return True
+    except Exception as e:
+        # Check if this is a port-in-use error (only rank 0 binds, so only rank 0 can hit it).
+        if "EADDRINUSE" in str(e) or "address already in use" in str(e).lower():
+            ad_logger.warning(f"Port {port} already in use, will retry with new port")
+            return False
+        raise
+
+
+def initialize(
+    rank: int = 0,
+    world_size: int = 1,
+    port: Optional[int] = None,
+    shared_port: Optional["mp.Value"] = None,
+    port_ready_barrier: Optional["mp.Barrier"] = None,
+    max_retries: int = 5,
+) -> Tuple[int, int]:
+    """Initialize the distributed process group.
+
+    Args:
+        rank: Process rank (ignored for OMPI/torchelastic).
+        world_size: Total number of processes (ignored for OMPI/torchelastic).
+        port: Initial port to try. If None, a free port will be selected.
+        shared_port: Optional mp.Value through which rank 0 shares the final port with other ranks.
+        port_ready_barrier: Optional mp.Barrier to synchronize port selection across ranks.
+        max_retries: Maximum number of port retry attempts for rank 0.
+    """
     if is_ompi():
         lib = "OMPI"
         local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
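As a reference for the error check in _try_init_process_group, here is a minimal standalone sketch (not part of the change) of what EADDRINUSE looks like at the socket level; the string match above targets this failure mode when rank 0 tries to bind an already-taken MASTER_PORT:

import errno
import socket

s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s1.bind(("127.0.0.1", 0))          # let the OS pick a free port
port = s1.getsockname()[1]
s1.listen()

s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
    s2.bind(("127.0.0.1", port))   # second bind on the same port fails
except OSError as e:
    assert e.errno == errno.EADDRINUSE
    print(e)                        # e.g. "[Errno 98] Address already in use"
finally:
    s2.close()
    s1.close()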
@@ -131,25 +184,53 @@ def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -> Tuple[int, int]:
         port = get_free_port()

     ad_logger.set_rank(local_rank)
-    ad_logger.info(f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=}")
-
-    # Set up environment variable to run with mpirun
-    os.environ["RANK"] = str(local_rank)
-    os.environ["WORLD_SIZE"] = str(world_size)
-    os.environ["MASTER_ADDR"] = "127.0.0.1"
-    os.environ["MASTER_PORT"] = str(port)
-    os.environ["LOCAL_RANK"] = str(local_rank)

     # Necessary to assign a device to each rank.
     torch.cuda.set_device(local_rank)

-    # We use nccl backend
-    dist.init_process_group(
-        "nccl",
-        world_size=world_size,
-        rank=local_rank,
-        device_id=torch.device(local_rank),
-    )
+    # If we have shared port synchronization (multiprocess spawn mode)
+    if shared_port is not None and port_ready_barrier is not None:
+        if local_rank == 0:
+            # Rank 0: try ports until one works, then share with other ranks
+            for attempt in range(max_retries):
+                ad_logger.info(
+                    f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=} (attempt {attempt + 1})"
+                )
+                if _try_init_process_group(local_rank, world_size, port):
+                    # Success! Share the working port with other ranks
+                    shared_port.value = port
+                    port_ready_barrier.wait()  # Signal other ranks
+                    break
+                else:
+                    # Port was taken, try a new one
+                    port = get_free_port()
+            else:
+                # All retries exhausted
+                shared_port.value = -1  # Signal failure
+                port_ready_barrier.wait()
+                raise RuntimeError(f"Failed to find available port after {max_retries} attempts")
+        else:
+            # Other ranks: wait for rank 0 to find a working port
+            port_ready_barrier.wait()
+            port = shared_port.value
+            if port == -1:
+                raise RuntimeError("Rank 0 failed to initialize, cannot proceed")
+            ad_logger.info(f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=}")
+            dist.init_process_group(
+                "nccl",
+                world_size=world_size,
+                rank=local_rank,
+                device_id=torch.device(local_rank),
+            )
+    else:
+        # Original path: no retry mechanism (OMPI, torchelastic, or single process)
+        ad_logger.info(f"Initializing for: {lib=}, {local_rank=}, {world_size=}, {port=}")
+        dist.init_process_group(
+            "nccl",
+            world_size=world_size,
+            rank=local_rank,
+            device_id=torch.device(local_rank),
+        )

     # Register cleanup function to be called at exit
     atexit.register(cleanup)
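get_free_port() is called on each retry but is not defined in this diff; a typical implementation (an assumption here, the real helper may differ) asks the OS for an ephemeral port, which is also why the retry loop exists: the chosen port can be re-taken between selection and bind.

# Assumed shape of get_free_port(); illustrative only.
import socket

def get_free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))   # port 0: the OS picks a free ephemeral port
        return s.getsockname()[1]  # another process may grab it before we bind again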
@@ -160,9 +241,13 @@ def initialize(rank: int = 0, world_size: int = 1, port: Optional[int] = None) -> Tuple[int, int]:
     return local_rank, world_size


-def init_and_run_process(job, rank, size, port, **kwargs):
+def init_and_run_process(
+    job, rank, size, port, shared_port=None, port_ready_barrier=None, **kwargs
+):
     try:
-        initialize_or_skip(rank, size, port)
+        initialize_or_skip(
+            rank, size, port, shared_port=shared_port, port_ready_barrier=port_ready_barrier
+        )
         job(rank, size, **kwargs)
     except Exception as e:
         # Close the input and output queues so the parent process can exit.
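For context, a hypothetical job callable compatible with this runner; only the job(rank, size, **kwargs) calling convention and the queue kwargs wired up in _start_multiprocess_job below are implied by the diff, the body is illustrative:

# Hypothetical worker job; names and logic are placeholders.
def job(rank: int, size: int, input_queue=None, output_queue=None):
    request = input_queue.get() if input_queue is not None else None
    result = f"rank {rank}/{size} handled {request!r}"
    if output_queue is not None and rank == 0:   # only rank 0 gets an output queue
        output_queue.put(result)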
@@ -212,19 +297,27 @@ def _start_multiprocess_job(
         init_and_run_process(job, 0, 1, port, **kwargs)
         return None

-    mp.set_start_method("spawn", force=True)
+    # Use explicit spawn context to ensure synchronization primitives work correctly
+    ctx = mp.get_context("spawn")
     processes: List[mp.Process] = []

+    # Create shared state for port synchronization with retry mechanism:
+    # - shared_port: rank 0 writes the final working port here
+    # - port_ready_barrier: all ranks wait here until rank 0 has bound successfully
+    shared_port = ctx.Value("i", port)  # 'i' = signed int
+    port_ready_barrier = ctx.Barrier(size)
+
     for rank in range(size):
         if input_queues:
             kwargs["input_queue"] = input_queues[rank]
         if output_queue:
             kwargs["output_queue"] = output_queue if rank == 0 else None

-        # Use thread for the single worker case.
-        launch_method = mp.Process
-        p = launch_method(
-            target=init_and_run_process, args=(job, rank, size, port), kwargs=kwargs, daemon=True
+        p = ctx.Process(
+            target=init_and_run_process,
+            args=(job, rank, size, port),
+            kwargs={**kwargs, "shared_port": shared_port, "port_ready_barrier": port_ready_barrier},
+            daemon=True,
         )
         p.start()
         processes.append(p)
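The handshake the spawn path relies on can be exercised in isolation; below is a self-contained toy sketch (separate from the change, with a made-up port number) of the ctx.Value / ctx.Barrier pattern: rank 0 publishes a value, and every process meets at the barrier before reading it.

import multiprocessing as mp

def worker(rank, shared_port, barrier):
    if rank == 0:
        shared_port.value = 29501   # stand-in for "the port that bound successfully"
    barrier.wait()                  # nobody reads until rank 0 has written
    print(f"rank {rank} sees port {shared_port.value}")

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    shared_port = ctx.Value("i", -1)
    barrier = ctx.Barrier(4)
    procs = [ctx.Process(target=worker, args=(r, shared_port, barrier)) for r in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()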