
Commit fa14399

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent c2bb086 commit fa14399

9 files changed: +331 -341 lines changed

colossalai/initialize.py

Lines changed: 44 additions & 40 deletions
@@ -2,8 +2,8 @@
 # -*- encoding: utf-8 -*-
 
 import os
-import time
 import socket
+import time
 from datetime import timedelta
 
 # set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation,
@@ -23,21 +23,21 @@
 def _wait_for_master_ready(host: str, port: int, timeout: int = 300, retry_interval: int = 5) -> bool:
     """
     Wait for the master node to be ready for distributed training connections.
-
+
     This is particularly useful in Kubernetes environments where pods start at different times.
-
+
     Args:
         host (str): Master node hostname or IP address
         port (int): Master node port
         timeout (int): Maximum time to wait in seconds (default: 300)
         retry_interval (int): Time between connection attempts in seconds (default: 5)
-
+
     Returns:
         bool: True if master is ready, False if timeout exceeded
     """
     start_time = time.time()
     logger = get_dist_logger()
-
+
     while time.time() - start_time < timeout:
         try:
             # Attempt to connect to the master node
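
For context on the function touched by this hunk and the next: _wait_for_master_ready polls the master endpoint with plain TCP connection attempts until it answers or the timeout expires. Below is a minimal, self-contained sketch of that probe pattern; the helper name and socket settings are illustrative, not the committed implementation.

import socket
import time


def wait_for_tcp_endpoint(host: str, port: int, timeout: int = 300, retry_interval: int = 5) -> bool:
    # Keep retrying a TCP connect until the endpoint accepts or the deadline passes.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with socket.create_connection((host, port), timeout=retry_interval):
                return True  # endpoint is accepting connections
        except OSError:
            time.sleep(retry_interval)
    return False
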
@@ -48,15 +48,15 @@ def _wait_for_master_ready(host: str, port: int, timeout: int = 300, retry_interval: int = 5) -> bool:
         except (socket.error, socket.timeout, ConnectionRefusedError, OSError) as e:
             logger.debug(f"Waiting for master node {host}:{port} to be ready... ({e})", ranks=[0])
             time.sleep(retry_interval)
-
+
     logger.error(f"Master node {host}:{port} did not become ready within {timeout} seconds", ranks=[0])
     return False
 
 
 def _get_distributed_timeout() -> timedelta:
     """
     Get the distributed training timeout from environment variables or use sensible defaults.
-
+
     Returns:
         timedelta: Timeout for distributed training initialization
     """
@@ -97,45 +97,47 @@ def launch(
 
     cur_accelerator = get_accelerator()
     backend = cur_accelerator.communication_backend
-
+
     logger = get_dist_logger() if verbose else None
 
     # Wait for master node to be ready (especially important for K8s environments)
     if rank != 0:  # Non-master ranks should wait for master to be ready
         if logger:
             logger.info(f"Rank {rank}: Waiting for master node {host}:{port} to be ready...")
-
+
         master_ready_timeout = int(os.environ.get("COLOSSALAI_MASTER_READY_TIMEOUT", "300"))
         if not _wait_for_master_ready(host, port, timeout=master_ready_timeout):
-            raise RuntimeError(f"Master node {host}:{port} is not ready for connections after {master_ready_timeout} seconds")
+            raise RuntimeError(
+                f"Master node {host}:{port} is not ready for connections after {master_ready_timeout} seconds"
+            )
 
     # init default process group with enhanced timeout and error handling
     if ":" in host:  # IPv6
         init_method = f"tcp://[{host}]:{port}"
     else:  # IPv4
         init_method = f"tcp://{host}:{port}"
-
+
     # Get timeout from environment or use default
     timeout = _get_distributed_timeout()
-
+
     if logger:
-        logger.info(f"Initializing distributed process group: rank={rank}, world_size={world_size}, "
-                    f"backend={backend}, init_method={init_method}, timeout={timeout}")
-
+        logger.info(
+            f"Initializing distributed process group: rank={rank}, world_size={world_size}, "
+            f"backend={backend}, init_method={init_method}, timeout={timeout}"
+        )
+
     try:
         dist.init_process_group(
-            rank=rank,
-            world_size=world_size,
-            backend=backend,
-            init_method=init_method,
-            timeout=timeout
+            rank=rank, world_size=world_size, backend=backend, init_method=init_method, timeout=timeout
         )
     except Exception as e:
         if logger:
             logger.error(f"Failed to initialize distributed process group: {e}")
-            logger.error(f"Please check: 1) Master node {host}:{port} is accessible, "
-                         f"2) All nodes use the same MASTER_ADDR/MASTER_PORT, "
-                         f"3) Network connectivity between nodes")
+            logger.error(
+                f"Please check: 1) Master node {host}:{port} is accessible, "
+                f"2) All nodes use the same MASTER_ADDR/MASTER_PORT, "
+                f"3) Network connectivity between nodes"
+            )
         raise RuntimeError(f"Distributed initialization failed: {e}") from e
 
     # set cuda device
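
As a usage illustration (not part of this commit): the readiness wait in the hunk above is controlled by the COLOSSALAI_MASTER_READY_TIMEOUT environment variable, which a caller can raise before invoking launch(). The keyword names below are the ones used in the hunk; the full launch() signature is not shown here, so other parameters (e.g. local_rank, seed) are omitted and the surrounding script is assumed.

import os

import colossalai

# Allow non-master ranks to wait up to 10 minutes for the master pod (default is 300 s).
os.environ["COLOSSALAI_MASTER_READY_TIMEOUT"] = "600"

colossalai.launch(
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
    host=os.environ["MASTER_ADDR"],
    port=int(os.environ["MASTER_PORT"]),
)
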
@@ -242,29 +244,31 @@ def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = True):
         verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     logger = get_dist_logger() if verbose else None
-
+
     # Validate required environment variables with detailed error messages
     required_envs = {
         "RANK": "Global rank of the current process",
-        "LOCAL_RANK": "Local rank of the process on the current node",
+        "LOCAL_RANK": "Local rank of the process on the current node",
         "WORLD_SIZE": "Total number of processes across all nodes",
         "MASTER_ADDR": "IP address or hostname of the master node",
-        "MASTER_PORT": "Port number for distributed communication"
+        "MASTER_PORT": "Port number for distributed communication",
     }
-
+
     missing_envs = []
     for env_var, description in required_envs.items():
         if env_var not in os.environ:
             missing_envs.append(f" - {env_var}: {description}")
-
+
     if missing_envs:
-        error_msg = ("Missing required environment variables for distributed training:\n" +
-                     "\n".join(missing_envs) +
-                     "\n\nFor Kubernetes multi-node training, ensure you're using enhanced torchrun command:\n"
-                     "torchrun --nnodes=N --nproc_per_node=M --rdzv_backend=c10d \\\n"
-                     " --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_id=$JOB_ID \\\n"
-                     " --node_rank=$NODE_RANK your_script.py\n\n"
-                     "Visit https://www.colossalai.org/ for more information on launching with torch")
+        error_msg = (
+            "Missing required environment variables for distributed training:\n"
+            + "\n".join(missing_envs)
+            + "\n\nFor Kubernetes multi-node training, ensure you're using enhanced torchrun command:\n"
+            "torchrun --nnodes=N --nproc_per_node=M --rdzv_backend=c10d \\\n"
+            " --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv_id=$JOB_ID \\\n"
+            " --node_rank=$NODE_RANK your_script.py\n\n"
+            "Visit https://www.colossalai.org/ for more information on launching with torch"
+        )
         raise RuntimeError(error_msg)
 
     try:
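
The five variables validated in this hunk are exactly the ones torchrun exports to each worker process. For illustration only, and assuming the launch_from_torch signature shown in the hunk header above, a training script started with the torchrun command quoted in the error message could begin like this:

import os

import colossalai

# torchrun sets RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for us;
# fail early with a readable message if the script was started without it.
required = ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")
missing = [name for name in required if name not in os.environ]
if missing:
    raise SystemExit(f"Start this script with torchrun; missing: {', '.join(missing)}")

colossalai.launch_from_torch(backend="nccl", seed=1024, verbose=True)
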
@@ -275,17 +279,17 @@ def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = True):
         port = int(os.environ["MASTER_PORT"])
     except ValueError as e:
         raise RuntimeError(f"Invalid environment variable value: {e}. All rank and port values must be integers.")
-
+
     # Additional validation for common misconfigurations
     if rank >= world_size:
         raise RuntimeError(f"RANK ({rank}) must be less than WORLD_SIZE ({world_size})")
-
+
     if local_rank < 0:
         raise RuntimeError(f"LOCAL_RANK ({local_rank}) must be non-negative")
-
+
     if port < 1024 or port > 65535:
         raise RuntimeError(f"MASTER_PORT ({port}) must be between 1024 and 65535")
-
+
     # Log distributed training configuration for debugging
     if logger and verbose:
         logger.info(f"Starting distributed training with configuration:")
@@ -295,7 +299,7 @@ def launch_from_torch(backend: str = "nccl", seed: int = 1024, verbose: bool = True):
         logger.info(f" MASTER_ADDR: {host}")
         logger.info(f" MASTER_PORT: {port}")
         logger.info(f" BACKEND: {backend}")
-
+
         # Log additional environment variables that might be relevant for debugging
         debug_envs = ["NODE_RANK", "NCCL_DEBUG", "GLOO_SOCKET_IFNAME", "NCCL_SOCKET_IFNAME", "RDZV_ID"]
         for env_var in debug_envs:
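
The debug_envs listed above are standard NCCL/Gloo and torchrun rendezvous variables; when present they are echoed into the startup log. A caller who wants extra connectivity logging could set them before launching, roughly as below (the interface name is a placeholder, not taken from this commit):

import os

# Standard NCCL/Gloo debugging knobs; "eth0" stands in for the pod's real network interface.
os.environ.setdefault("NCCL_DEBUG", "INFO")
os.environ.setdefault("NCCL_SOCKET_IFNAME", "eth0")
os.environ.setdefault("GLOO_SOCKET_IFNAME", "eth0")
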

colossalai/utils/__init__.py

Lines changed: 8 additions & 7 deletions
@@ -16,15 +16,16 @@
 # Kubernetes distributed training utilities
 try:
     from .k8s_distributed import (
-        validate_k8s_environment,
-        setup_k8s_networking,
-        diagnose_distributed_issues,
-        generate_torchrun_command,
         create_k8s_headless_service_yaml,
         create_k8s_job_yaml,
+        diagnose_distributed_issues,
+        generate_torchrun_command,
+        setup_k8s_networking,
+        validate_k8s_environment,
     )
+
     _k8s_utils_available = True
-
+
     __all__ = [
         "conditional_context",
         "Timer",
@@ -41,15 +42,15 @@
         "get_non_persistent_buffers_set",
         # K8s distributed training utilities
         "validate_k8s_environment",
-        "setup_k8s_networking",
+        "setup_k8s_networking",
         "diagnose_distributed_issues",
         "generate_torchrun_command",
         "create_k8s_headless_service_yaml",
         "create_k8s_job_yaml",
     ]
 except ImportError:
     _k8s_utils_available = False
-
+
     __all__ = [
         "conditional_context",
         "Timer",

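
The re-ordered imports above sit inside a try/except ImportError, so the K8s helpers are optional and downstream code should not assume they are importable. A hedged usage sketch; only the module path and exported names are taken from the diff, the caller is hypothetical:

# Check availability of the optional K8s distributed helpers before relying on them.
try:
    from colossalai.utils import validate_k8s_environment  # noqa: F401

    k8s_helpers_available = True
except ImportError:
    k8s_helpers_available = False

print(f"K8s distributed helpers available: {k8s_helpers_available}")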