|
@@ -14,6 +14,10 @@
 
 import torch
 
+from forge.controller.launcher import BaseLauncher, get_launcher
+from forge.env import all_env_vars, FORGE_DISABLE_METRICS
+from forge.types import ProcessConfig, ProvisionerConfig
+
 from monarch._src.actor.actor_mesh import ActorMesh
 from monarch._src.actor.shape import Extent
 
@@ -22,10 +26,6 @@
 from monarch.tools import commands
 from monarch.utils import setup_env_for_distributed
 
-from forge.controller.launcher import BaseLauncher, get_launcher
-from forge.env import all_env_vars, FORGE_DISABLE_METRICS
-from forge.types import ProcessConfig, ProvisionerConfig
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
@@ -134,6 +134,23 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]): |
     await env_setter.set_env.call(env_vars)
 
 
+async def get_nccl_env_vars() -> dict[str, str]:
+    """Get NCCL environment variables by detecting network interfaces."""
+    if "NCCL_SOCKET_IFNAME" in os.environ and "NCCL_IB_DISABLE" in os.environ:
+        return {}
+
+    try:
+        interfaces = os.listdir("/sys/class/net/")
+        ib_interfaces = [i for i in interfaces if i.startswith("ib")]
+
+        return {
+            "NCCL_SOCKET_IFNAME": ",".join(ib_interfaces) if ib_interfaces else "^lo",
+            "NCCL_IB_DISABLE": "0" if ib_interfaces else "1",
+        }
+    except Exception:
+        return {"NCCL_SOCKET_IFNAME": "^lo", "NCCL_IB_DISABLE": "1"}
+
+
 class GpuManager:
     """Tracks and assigns GPU devices on a host.
 
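
For context, a minimal sketch of how the new helper behaves. The import path `forge.controller.provisioner` is an assumption (this diff does not show the module name), and the `/sys/class/net` contents described in the comments are hypothetical:

```python
# Sketch only: the import path below is assumed, not confirmed by this diff.
import asyncio
import os

from forge.controller.provisioner import get_nccl_env_vars  # hypothetical path


async def main() -> None:
    # Clear any pre-set values so the detection branch runs; if both
    # variables are already set, the helper returns {} and defers to them.
    os.environ.pop("NCCL_SOCKET_IFNAME", None)
    os.environ.pop("NCCL_IB_DISABLE", None)

    nccl_vars = await get_nccl_env_vars()
    # With ib0/ib1 present under /sys/class/net (hypothetical host):
    #   {"NCCL_SOCKET_IFNAME": "ib0,ib1", "NCCL_IB_DISABLE": "0"}
    # On an Ethernet-only host (no ib* interfaces):
    #   {"NCCL_SOCKET_IFNAME": "^lo", "NCCL_IB_DISABLE": "1"}
    print(nccl_vars)


asyncio.run(main())
```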
|
@@ -347,11 +364,16 @@ async def get_proc_mesh( |
     if with_gpus:
         if not addr or not port:
             addr, port = await get_remote_info(host_mesh)
-        gpu_ids = gpu_manager.get_gpus(num_procs)
+        gpu_ids: list[str] = gpu_manager.get_gpus(num_procs)
 
+        # Set PyTorch distributed environment variables
         env_vars["MASTER_ADDR"] = addr
         env_vars["MASTER_PORT"] = port
 
+        # Get NCCL-specific environment variables
+        nccl_vars = await get_nccl_env_vars()
+        env_vars.update(nccl_vars)
+
         # Set the PTD world size
         world_size = num_procs * (num_hosts or 1)
         env_vars["WORLD_SIZE"] = str(world_size)
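
Downstream, a worker spawned with these variables could initialize PyTorch distributed roughly as follows. This is an illustrative sketch, not Forge's actual worker entry point, and it assumes `RANK`/`LOCAL_RANK` are also populated by the provisioner:

```python
# Illustrative sketch, assuming RANK and LOCAL_RANK are set alongside the
# variables above; this is not Forge's actual startup path.
import os

import torch
import torch.distributed as dist


def init_worker() -> None:
    # The default init_method="env://" reads MASTER_ADDR, MASTER_PORT,
    # WORLD_SIZE, and RANK from the environment populated by get_proc_mesh.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    # NCCL itself picks up NCCL_SOCKET_IFNAME / NCCL_IB_DISABLE when the
    # first collective is issued.
    dist.barrier()
```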
|