 # -*- encoding: utf-8 -*-
 
 import os
-import time
 import socket
+import time
 from datetime import timedelta
 
 # set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation,
 def _wait_for_master_ready(host: str, port: int, timeout: int = 300, retry_interval: int = 5) -> bool:
     """
     Wait for the master node to be ready for distributed training connections.
-
+
     This is particularly useful in Kubernetes environments where pods start at different times.
-
+
     Args:
         host (str): Master node hostname or IP address
         port (int): Master node port
         timeout (int): Maximum time to wait in seconds (default: 300)
         retry_interval (int): Time between connection attempts in seconds (default: 5)
-
+
     Returns:
         bool: True if master is ready, False if timeout exceeded
     """
     start_time = time.time()
     logger = get_dist_logger()
-
+
     while time.time() - start_time < timeout:
         try:
             # Attempt to connect to the master node
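The connection attempt itself falls outside this hunk. For readers following along, here is a minimal sketch of what such a probe typically looks like, using only the standard-library socket API; it is an illustration, not necessarily the commit's exact code:

```python
import socket

def _probe_master(host: str, port: int, probe_timeout: float = 5.0) -> bool:
    # Open a TCP connection to the rendezvous port and close it right away.
    # Success means the master is accepting connections; any OSError is left
    # for the caller's retry loop (the except clause below) to handle.
    with socket.create_connection((host, port), timeout=probe_timeout):
        return True
```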
@@ -48,15 +48,15 @@ def _wait_for_master_ready(host: str, port: int, timeout: int = 300, retry_inter
         except (socket.error, socket.timeout, ConnectionRefusedError, OSError) as e:
             logger.debug(f"Waiting for master node {host}:{port} to be ready... ({e})", ranks=[0])
             time.sleep(retry_interval)
-
+
     logger.error(f"Master node {host}:{port} did not become ready within {timeout} seconds", ranks=[0])
     return False
 
 
 def _get_distributed_timeout() -> timedelta:
     """
     Get the distributed training timeout from environment variables or use sensible defaults.
-
+
     Returns:
         timedelta: Timeout for distributed training initialization
     """
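The body of `_get_distributed_timeout` is not part of this hunk. The usual pattern is to read a number of seconds from an environment variable and fall back to a default; a sketch follows, in which the variable name `COLOSSALAI_DIST_TIMEOUT` and the 30-minute fallback are assumptions rather than values taken from the commit:

```python
import os
from datetime import timedelta

def _get_distributed_timeout_sketch() -> timedelta:
    # Hypothetical variable name and default; the commit may use different ones.
    seconds = int(os.environ.get("COLOSSALAI_DIST_TIMEOUT", str(30 * 60)))
    return timedelta(seconds=seconds)
```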
@@ -97,45 +97,47 @@ def launch(
 
     cur_accelerator = get_accelerator()
     backend = cur_accelerator.communication_backend
-
+
     logger = get_dist_logger() if verbose else None
 
     # Wait for master node to be ready (especially important for K8s environments)
     if rank != 0:  # Non-master ranks should wait for master to be ready
         if logger:
             logger.info(f"Rank {rank}: Waiting for master node {host}:{port} to be ready...")
-
+
         master_ready_timeout = int(os.environ.get("COLOSSALAI_MASTER_READY_TIMEOUT", "300"))
         if not _wait_for_master_ready(host, port, timeout=master_ready_timeout):
-            raise RuntimeError(f"Master node {host}:{port} is not ready for connections after {master_ready_timeout} seconds")
+            raise RuntimeError(
+                f"Master node {host}:{port} is not ready for connections after {master_ready_timeout} seconds"
+            )
 
     # init default process group with enhanced timeout and error handling
     if ":" in host:  # IPv6
         init_method = f"tcp://[{host}]:{port}"
     else:  # IPv4
         init_method = f"tcp://{host}:{port}"
-
+
     # Get timeout from environment or use default
     timeout = _get_distributed_timeout()
-
+
     if logger:
-        logger.info(f"Initializing distributed process group: rank={rank}, world_size={world_size}, "
-                    f"backend={backend}, init_method={init_method}, timeout={timeout}")
-
+        logger.info(
+            f"Initializing distributed process group: rank={rank}, world_size={world_size}, "
+            f"backend={backend}, init_method={init_method}, timeout={timeout}"
+        )
+
     try:
         dist.init_process_group(
-            rank=rank,
-            world_size=world_size,
-            backend=backend,
-            init_method=init_method,
-            timeout=timeout
+            rank=rank, world_size=world_size, backend=backend, init_method=init_method, timeout=timeout
         )
     except Exception as e:
         if logger:
             logger.error(f"Failed to initialize distributed process group: {e}")
-            logger.error(f"Please check: 1) Master node {host}:{port} is accessible, "
-                         f"2) All nodes use the same MASTER_ADDR/MASTER_PORT, "
-                         f"3) Network connectivity between nodes")
+            logger.error(
+                f"Please check: 1) Master node {host}:{port} is accessible, "
+                f"2) All nodes use the same MASTER_ADDR/MASTER_PORT, "
+                f"3) Network connectivity between nodes"
+            )
         raise RuntimeError(f"Distributed initialization failed: {e}") from e
 
     # set cuda device
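Two details of the hunk above are easy to check in isolation: IPv6 literals are wrapped in brackets when the TCP init method is built, and non-master ranks honour `COLOSSALAI_MASTER_READY_TIMEOUT` (default 300 seconds) while waiting for the master. A small self-contained snippet that mirrors that logic rather than importing it:

```python
import os

def build_init_method(host: str, port: int) -> str:
    # Same branch as in launch(): IPv6 literals need brackets in a URL authority.
    return f"tcp://[{host}]:{port}" if ":" in host else f"tcp://{host}:{port}"

assert build_init_method("10.0.0.1", 29500) == "tcp://10.0.0.1:29500"
assert build_init_method("fd00::1", 29500) == "tcp://[fd00::1]:29500"

# Give slow-starting pods more time before non-master ranks give up (default: 300 s).
os.environ["COLOSSALAI_MASTER_READY_TIMEOUT"] = "600"
```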
@@ -242,29 +244,31 @@ def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = T
         verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     logger = get_dist_logger() if verbose else None
-
+
     # Validate required environment variables with detailed error messages
     required_envs = {
         "RANK": "Global rank of the current process",
-        "LOCAL_RANK": "Local rank of the process on the current node",
+        "LOCAL_RANK": "Local rank of the process on the current node",
         "WORLD_SIZE": "Total number of processes across all nodes",
         "MASTER_ADDR": "IP address or hostname of the master node",
-        "MASTER_PORT": "Port number for distributed communication"
+        "MASTER_PORT": "Port number for distributed communication",
     }
-
+
     missing_envs = []
     for env_var, description in required_envs.items():
         if env_var not in os.environ:
             missing_envs.append(f" - {env_var}: {description}")
-
+
     if missing_envs:
-        error_msg = ("Missing required environment variables for distributed training:\n" +
-                     "\n".join(missing_envs) +
-                     "\n\nFor Kubernetes multi-node training, ensure you're using enhanced torchrun command:\n"
-                     "torchrun --nnodes=N --nproc_per_node=M --rdzv_backend=c10d \\\n"
-                     " --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_id=$JOB_ID \\\n"
-                     " --node_rank=$NODE_RANK your_script.py\n\n"
-                     "Visit https://www.colossalai.org/ for more information on launching with torch")
+        error_msg = (
+            "Missing required environment variables for distributed training:\n"
+            + "\n".join(missing_envs)
+            + "\n\nFor Kubernetes multi-node training, ensure you're using enhanced torchrun command:\n"
+            "torchrun --nnodes=N --nproc_per_node=M --rdzv_backend=c10d \\\n"
+            " --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_id=$JOB_ID \\\n"
+            " --node_rank=$NODE_RANK your_script.py\n\n"
+            "Visit https://www.colossalai.org/ for more information on launching with torch"
+        )
         raise RuntimeError(error_msg)
 
     try:
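When a Kubernetes launch fails this validation, it usually helps to print what torchrun actually exported in each pod before calling `launch_from_torch`. A throwaway check using the same variable names validated above (plus `NODE_RANK`, which the debug logging further down also inspects):

```python
import os

for name in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "NODE_RANK"):
    print(f"{name}={os.environ.get(name, '<unset>')}")
```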
@@ -275,17 +279,17 @@ def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = T
         port = int(os.environ["MASTER_PORT"])
     except ValueError as e:
         raise RuntimeError(f"Invalid environment variable value: {e}. All rank and port values must be integers.")
-
+
     # Additional validation for common misconfigurations
     if rank >= world_size:
         raise RuntimeError(f"RANK ({rank}) must be less than WORLD_SIZE ({world_size})")
-
+
     if local_rank < 0:
         raise RuntimeError(f"LOCAL_RANK ({local_rank}) must be non-negative")
-
+
     if port < 1024 or port > 65535:
         raise RuntimeError(f"MASTER_PORT ({port}) must be between 1024 and 65535")
-
+
     # Log distributed training configuration for debugging
     if logger and verbose:
         logger.info(f"Starting distributed training with configuration:")
@@ -295,7 +299,7 @@ def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = T
         logger.info(f" MASTER_ADDR: {host}")
         logger.info(f" MASTER_PORT: {port}")
         logger.info(f" BACKEND: {backend}")
-
+
         # Log additional environment variables that might be relevant for debugging
         debug_envs = ["NODE_RANK", "NCCL_DEBUG", "GLOO_SOCKET_IFNAME", "NCCL_SOCKET_IFNAME", "RDZV_ID"]
         for env_var in debug_envs:
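Putting the pieces together: a training script only needs to call `launch_from_torch` (its signature is visible in the hunk headers above) and then relies on the validation, readiness wait, and timeout handling added here. A minimal usage sketch, assuming the usual top-level re-export of `launch_from_torch` from the `colossalai` package:

```python
# train.py -- minimal usage sketch
import colossalai

colossalai.launch_from_torch(backend="nccl", seed=1024, verbose=True)
# ... build the model, dataloaders, and training loop as usual ...
```

Each node then starts the script with the torchrun command quoted in the error message above, e.g. `torchrun --nnodes=N --nproc_per_node=M --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_id=$JOB_ID --node_rank=$NODE_RANK train.py`.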