|
2 | 2 | # -*- encoding: utf-8 -*- |
3 | 3 |
|
4 | 4 | import os |
| 5 | +import time |
| 6 | +import socket |
| 7 | +from datetime import timedelta |
5 | 8 |
|
6 | 9 | # set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation, |
7 | 10 | # the order of kernel launches on GPUs is the same as on the CPU so that comm is launched first.
|
17 | 20 | from colossalai.utils import set_seed |
18 | 21 |
|
19 | 22 |
|
| 23 | +def _wait_for_master_ready(host: str, port: int, timeout: int = 300, retry_interval: int = 5) -> bool: |
| 24 | + """ |
| 25 | + Wait for the master node to be ready for distributed training connections. |
| 26 | + |
| 27 | + This is particularly useful in Kubernetes environments where pods start at different times. |
| 28 | + |
| 29 | + Args: |
| 30 | + host (str): Master node hostname or IP address |
| 31 | + port (int): Master node port |
| 32 | + timeout (int): Maximum time to wait in seconds (default: 300) |
| 33 | + retry_interval (int): Time between connection attempts in seconds (default: 5) |
| 34 | + |
| 35 | + Returns: |
| 36 | + bool: True if master is ready, False if timeout exceeded |
| 37 | + """ |
| 38 | + start_time = time.time() |
| 39 | + logger = get_dist_logger() |
| 40 | + |
| 41 | + while time.time() - start_time < timeout: |
| 42 | + try: |
| 43 | + # Attempt to connect to the master node |
| 44 | + sock = socket.create_connection((host, port), timeout=10) |
| 45 | + sock.close() |
| 46 | + logger.info(f"Master node {host}:{port} is ready for connections", ranks=[0]) |
| 47 | + return True |
| 48 | + except (socket.error, socket.timeout, ConnectionRefusedError, OSError) as e: |
| 49 | + logger.debug(f"Waiting for master node {host}:{port} to be ready... ({e})", ranks=[0]) |
| 50 | + time.sleep(retry_interval) |
| 51 | + |
| 52 | + logger.error(f"Master node {host}:{port} did not become ready within {timeout} seconds", ranks=[0]) |
| 53 | + return False |
| 54 | + |
| 55 | + |
| 56 | +def _get_distributed_timeout() -> timedelta: |
| 57 | + """ |
| 58 | + Get the distributed training timeout from environment variables or use sensible defaults. |
| 59 | + |
| 60 | + Returns: |
| 61 | + timedelta: Timeout for distributed training initialization |
| 62 | + """ |
| 63 | + # Check for user-defined timeout (in seconds) |
| 64 | + timeout_seconds = int(os.environ.get("COLOSSALAI_DIST_TIMEOUT", "1800")) # 30 minutes default |
| 65 | + return timedelta(seconds=timeout_seconds) |
| 66 | + |
| 67 | + |
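Both new helpers are driven entirely by environment variables, so the behaviour can be tuned without touching code. A minimal sketch of a driver script that sets these knobs before handing control to ColossalAI (the variable names come from this diff; the 600 s / 3600 s values and the top-level colossalai.launch_from_torch call are illustrative assumptions):

    import os

    # Illustrative values; both variables are read by the helpers added in this diff.
    os.environ.setdefault("COLOSSALAI_MASTER_READY_TIMEOUT", "600")   # wait up to 10 min for the master pod
    os.environ.setdefault("COLOSSALAI_DIST_TIMEOUT", "3600")          # 1 h timeout for process-group setup

    import colossalai

    # RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT are expected to be set by torchrun.
    colossalai.launch_from_torch(backend="nccl", seed=1024, verbose=True)
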
20 | 68 | def launch( |
21 | 69 | rank: int, |
22 | 70 | world_size: int, |
@@ -48,15 +96,47 @@ def launch( |
48 | 96 | """ |
49 | 97 |
|
50 | 98 | cur_accelerator = get_accelerator() |
51 | | - |
52 | 99 | backend = cur_accelerator.communication_backend |
| 100 | + |
| 101 | + logger = get_dist_logger() if verbose else None |
| 102 | + |
| 103 | + # Wait for master node to be ready (especially important for K8s environments) |
| 104 | + if rank != 0: # Non-master ranks should wait for master to be ready |
| 105 | + if logger: |
| 106 | + logger.info(f"Rank {rank}: Waiting for master node {host}:{port} to be ready...") |
| 107 | + |
| 108 | + master_ready_timeout = int(os.environ.get("COLOSSALAI_MASTER_READY_TIMEOUT", "300")) |
| 109 | + if not _wait_for_master_ready(host, port, timeout=master_ready_timeout): |
| 110 | + raise RuntimeError(f"Master node {host}:{port} is not ready for connections after {master_ready_timeout} seconds") |
53 | 111 |
|
54 | | - # init default process group |
| 112 | + # init default process group with enhanced timeout and error handling |
55 | 113 | if ":" in host: # IPv6 |
56 | 114 | init_method = f"tcp://[{host}]:{port}" |
57 | 115 | else: # IPv4 |
58 | 116 | init_method = f"tcp://{host}:{port}" |
59 | | - dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) |
| 117 | + |
| 118 | + # Get timeout from environment or use default |
| 119 | + timeout = _get_distributed_timeout() |
| 120 | + |
| 121 | + if logger: |
| 122 | + logger.info(f"Initializing distributed process group: rank={rank}, world_size={world_size}, " |
| 123 | + f"backend={backend}, init_method={init_method}, timeout={timeout}") |
| 124 | + |
| 125 | + try: |
| 126 | + dist.init_process_group( |
| 127 | + rank=rank, |
| 128 | + world_size=world_size, |
| 129 | + backend=backend, |
| 130 | + init_method=init_method, |
| 131 | + timeout=timeout |
| 132 | + ) |
| 133 | + except Exception as e: |
| 134 | + if logger: |
| 135 | + logger.error(f"Failed to initialize distributed process group: {e}") |
| 136 | + logger.error(f"Please check: 1) Master node {host}:{port} is accessible, " |
| 137 | + f"2) All nodes use the same MASTER_ADDR/MASTER_PORT, " |
| 138 | + f"3) Network connectivity between nodes") |
| 139 | + raise RuntimeError(f"Distributed initialization failed: {e}") from e |
60 | 140 |
|
61 | 141 | # set cuda device |
62 | 142 | # if local rank is not given, calculate automatically |
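The error path above asks users to verify that the master endpoint is reachable. Assuming an interactive Python shell on a worker node (hostname and port below are placeholders), the same TCP probe that _wait_for_master_ready performs can be reproduced directly:

    import socket

    # Substitute the job's real MASTER_ADDR / MASTER_PORT; raises OSError if the master is unreachable.
    socket.create_connection(("master-placeholder.example.svc", 29500), timeout=10).close()
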
@@ -161,16 +241,66 @@ def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = T |
161 | 241 | seed (int, optional): Specified random seed for every process. Defaults to 1024. |
162 | 242 | verbose (bool, optional): Whether to print logs. Defaults to True. |
163 | 243 | """ |
| 244 | + logger = get_dist_logger() if verbose else None |
| 245 | + |
| 246 | + # Validate required environment variables with detailed error messages |
| 247 | + required_envs = { |
| 248 | + "RANK": "Global rank of the current process", |
| 249 | + "LOCAL_RANK": "Local rank of the process on the current node", |
| 250 | + "WORLD_SIZE": "Total number of processes across all nodes", |
| 251 | + "MASTER_ADDR": "IP address or hostname of the master node", |
| 252 | + "MASTER_PORT": "Port number for distributed communication" |
| 253 | + } |
| 254 | + |
| 255 | + missing_envs = [] |
| 256 | + for env_var, description in required_envs.items(): |
| 257 | + if env_var not in os.environ: |
| 258 | + missing_envs.append(f" - {env_var}: {description}") |
| 259 | + |
| 260 | + if missing_envs: |
| 261 | + error_msg = ("Missing required environment variables for distributed training:\n" + |
| 262 | + "\n".join(missing_envs) + |
| 263 | + "\n\nFor Kubernetes multi-node training, ensure you're using enhanced torchrun command:\n" |
| 264 | + "torchrun --nnodes=N --nproc_per_node=M --rdzv_backend=c10d \\\n" |
| 265 | + " --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_id=$JOB_ID \\\n" |
| 266 | + " --node_rank=$NODE_RANK your_script.py\n\n" |
| 267 | + "Visit https://www.colossalai.org/ for more information on launching with torch") |
| 268 | + raise RuntimeError(error_msg) |
| 269 | + |
164 | 270 | try: |
165 | 271 | rank = int(os.environ["RANK"]) |
166 | 272 | local_rank = int(os.environ["LOCAL_RANK"]) |
167 | 273 | world_size = int(os.environ["WORLD_SIZE"]) |
168 | 274 | host = os.environ["MASTER_ADDR"] |
169 | 275 | port = int(os.environ["MASTER_PORT"]) |
170 | | - except KeyError as e: |
171 | | - raise RuntimeError( |
172 | | - f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch" |
173 | | - ) |
| 276 | + except ValueError as e: |
| 277 | + raise RuntimeError(f"Invalid environment variable value: {e}. All rank and port values must be integers.") |
| 278 | + |
| 279 | + # Additional validation for common misconfigurations |
| 280 | + if rank >= world_size: |
| 281 | + raise RuntimeError(f"RANK ({rank}) must be less than WORLD_SIZE ({world_size})") |
| 282 | + |
| 283 | + if local_rank < 0: |
| 284 | + raise RuntimeError(f"LOCAL_RANK ({local_rank}) must be non-negative") |
| 285 | + |
| 286 | + if port < 1024 or port > 65535: |
| 287 | + raise RuntimeError(f"MASTER_PORT ({port}) must be between 1024 and 65535") |
| 288 | + |
| 289 | + # Log distributed training configuration for debugging |
| 290 | + if logger:
| 291 | + logger.info("Starting distributed training with configuration:")
| 292 | + logger.info(f" RANK: {rank}") |
| 293 | + logger.info(f" LOCAL_RANK: {local_rank}") |
| 294 | + logger.info(f" WORLD_SIZE: {world_size}") |
| 295 | + logger.info(f" MASTER_ADDR: {host}") |
| 296 | + logger.info(f" MASTER_PORT: {port}") |
| 297 | + logger.info(f" BACKEND: {backend}") |
| 298 | + |
| 299 | + # Log additional environment variables that might be relevant for debugging |
| 300 | + debug_envs = ["NODE_RANK", "NCCL_DEBUG", "GLOO_SOCKET_IFNAME", "NCCL_SOCKET_IFNAME", "RDZV_ID"] |
| 301 | + for env_var in debug_envs: |
| 302 | + if env_var in os.environ: |
| 303 | + logger.info(f" {env_var}: {os.environ[env_var]}") |
174 | 304 |
|
175 | 305 | launch( |
176 | 306 | local_rank=local_rank, |
|
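For reference, the validation added to launch_from_torch amounts to requiring an environment like the one below, shown here for the second process on the second node of a hypothetical 2-node x 4-GPU job (all values are illustrative; torchrun normally exports them automatically):

    expected_env = {
        "MASTER_ADDR": "10.0.0.1",  # identical on every node
        "MASTER_PORT": "29500",     # integer in 1024-65535, identical on every node
        "WORLD_SIZE": "8",          # 2 nodes x 4 processes per node
        "RANK": "5",                # global rank, 0 <= RANK < WORLD_SIZE
        "LOCAL_RANK": "1",          # rank within the node, >= 0
    }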