Skip to content

Commit 401c71d

Browse files
author
Hossein Kavianihamedani
committed
Add NCCL auto-detection for InfiniBand and network interfaces
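- Add get_nccl_env_vars() function to automatically detect the network configuration
- Detects InfiniBand interfaces (any interface under /sys/class/net/ whose name starts with "ib", e.g. ib0, ibp*) and enables them automatically
- Falls back to any non-loopback interface (NCCL_SOCKET_IFNAME=^lo) with InfiniBand disabled (NCCL_IB_DISABLE=1) if no InfiniBand interface is found or detection fails
- Skips detection when both NCCL_SOCKET_IFNAME and NCCL_IB_DISABLE are already set by the user
- Sets NCCL_SOCKET_IFNAME and NCCL_IB_DISABLE based on the interfaces actually present
- Eliminates the need for manual NCCL configuration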
1 parent 6b65eb4 commit 401c71d

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

src/forge/controller/provisioner.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
import torch
1616

17+
from forge.controller.launcher import BaseLauncher, get_launcher
18+
from forge.env import all_env_vars, FORGE_DISABLE_METRICS
19+
from forge.types import ProcessConfig, ProvisionerConfig
20+
1721
from monarch._src.actor.actor_mesh import ActorMesh
1822
from monarch._src.actor.shape import Extent
1923

@@ -22,10 +26,6 @@
2226
from monarch.tools import commands
2327
from monarch.utils import setup_env_for_distributed
2428

25-
from forge.controller.launcher import BaseLauncher, get_launcher
26-
from forge.env import all_env_vars, FORGE_DISABLE_METRICS
27-
from forge.types import ProcessConfig, ProvisionerConfig
28-
2929
logger = logging.getLogger(__name__)
3030
logger.setLevel(logging.DEBUG)
3131

@@ -134,6 +134,23 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
134134
await env_setter.set_env.call(env_vars)
135135

136136

137+
async def get_nccl_env_vars() -> dict[str, str]:
    """Get NCCL environment variables by detecting network interfaces.

    Lists the interfaces under ``/sys/class/net/`` on the machine running
    this function and treats every interface whose name starts with ``"ib"``
    as InfiniBand.

    Any of ``NCCL_SOCKET_IFNAME`` / ``NCCL_IB_DISABLE`` that is already
    present in ``os.environ`` is considered user-configured and is left out
    of the result, so that merging the result into another mapping never
    overrides an explicit user choice.

    Returns:
        A dict with the detected value for each of the two variables that
        the user has not set. Empty when both are already set.
        When InfiniBand interfaces are found, ``NCCL_SOCKET_IFNAME`` is their
        comma-separated names and ``NCCL_IB_DISABLE`` is ``"0"``; otherwise
        (including when detection fails) they are ``"^lo"`` and ``"1"``.
    """
    # NOTE(review): detection inspects the local filesystem only; confirm the
    # calling host has the same NICs as the hosts that receive these values.
    managed_vars = ("NCCL_SOCKET_IFNAME", "NCCL_IB_DISABLE")
    unset_vars = [name for name in managed_vars if name not in os.environ]
    if not unset_vars:
        return {}

    try:
        # Sorted so the resulting value does not depend on directory order.
        ib_interfaces = sorted(
            name for name in os.listdir("/sys/class/net/") if name.startswith("ib")
        )
    except OSError:
        # Best effort: /sys/class/net/ is missing or unreadable (e.g. not
        # Linux), so fall back to the non-InfiniBand defaults below.
        ib_interfaces = []

    if ib_interfaces:
        detected = {
            "NCCL_SOCKET_IFNAME": ",".join(ib_interfaces),
            "NCCL_IB_DISABLE": "0",
        }
    else:
        # "^lo" excludes only the loopback interface.
        detected = {"NCCL_SOCKET_IFNAME": "^lo", "NCCL_IB_DISABLE": "1"}

    # Only hand back values for variables the user has not configured.
    return {name: detected[name] for name in unset_vars}
152+
153+
137154
class GpuManager:
138155
"""Tracks and assigns GPU devices on a host.
139156
@@ -347,11 +364,16 @@ async def get_proc_mesh(
347364
if with_gpus:
348365
if not addr or not port:
349366
addr, port = await get_remote_info(host_mesh)
350-
gpu_ids = gpu_manager.get_gpus(num_procs)
367+
gpu_ids: list[str] = gpu_manager.get_gpus(num_procs)
351368

369+
# Set PyTorch distributed environment variables
352370
env_vars["MASTER_ADDR"] = addr
353371
env_vars["MASTER_PORT"] = port
354372

373+
# Get NCCL-specific environment variables
374+
nccl_vars = await get_nccl_env_vars()
375+
env_vars.update(nccl_vars)
376+
355377
# Set the PTD world size
356378
world_size = num_procs * (num_hosts or 1)
357379
env_vars["WORLD_SIZE"] = str(world_size)

0 commit comments

Comments
 (0)