## Summary
`nemo_rl/distributed/worker_groups.py` (around line 347) unconditionally removes `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` from the worker environment variables:

```python
# Remove Ray-specific environment variables, let the worker itself set them.
worker_env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)
```

This forces Ray to apply per-actor `CUDA_VISIBLE_DEVICES` masking (e.g., `CUDA_VISIBLE_DEVICES=3` for the 4th GPU), which triggers three confirmed NCCL bugs on H200/NVSwitch (P5en) hardware. The result: multi-node GRPO training is completely broken on H200.
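To make the masking concrete, here is a minimal, GPU-free sketch of the index remapping that per-actor masking produces (the helper name is illustrative, not part of NeMo RL's API):

```python
def visible_to_physical(cuda_visible_devices: str) -> dict[int, int]:
    """Map process-local CUDA device indices to physical GPU ids.

    With Ray's per-actor masking (e.g. CUDA_VISIBLE_DEVICES="3"), the
    actor's device 0 is physical GPU 3. NCCL builds its internal device
    table from the local indices, so the table no longer matches the
    physical NVSwitch topology.
    """
    ids = [int(x) for x in cuda_visible_devices.split(",") if x.strip()]
    return dict(enumerate(ids))

# Ray-style masking: one visible GPU per actor.
print(visible_to_physical("3"))                # {0: 3}
# torchrun-style: all 8 GPUs visible; the worker picks local_rank itself.
print(visible_to_physical("0,1,2,3,4,5,6,7"))  # {0: 0, 1: 1, ..., 7: 7}
```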
## NCCL Bugs Triggered by GPU Masking

When Ray masks `CUDA_VISIBLE_DEVICES` to a single GPU per actor, NCCL's internal device indexing diverges from the physical GPU topology. This triggers:
- **cuMem import penalty (NVIDIA/nccl#1749)** -- confirmed by an NVIDIA engineer. `p2pMap()` in `src/transport/p2p.cc` iterates over all devices when importing cuMem handles with non-overlapping `CUDA_VISIBLE_DEVICES`. Causes a 3,660 ms first-operation penalty (vs. 1.5 ms with torchrun).
- **NVLS rank ordering corruption (NVIDIA/nccl#1906)** -- confirmed by an NVIDIA engineer. The allgather in `src/transport/nvls.cc` is missing a user rank table when GPU indices are permuted by `CUDA_VISIBLE_DEVICES`. Causes a hang or silent data corruption on NVSwitch systems. Only affects NVSwitch (H200 P5en), not NVLink-only (H100 P5).
- **Multi-channel P2P hang at >8M elements** -- even with `NCCL_CUMEM_ENABLE=0` and `NCCL_NVLS_ENABLE=0`, AllReduce hangs for tensors larger than ~32 MB (8-12M float32 elements). This appears to be a separate multi-channel issue triggered by the same GPU masking.
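The ~32 MB figure for the third bug can be sanity-checked with back-of-the-envelope arithmetic from the reported element counts:

```python
# Hang threshold from the report: ~8-12M float32 elements.
BYTES_PER_F32 = 4
low_mb = 8_000_000 * BYTES_PER_F32 / 1e6    # lower bound, matches "~32 MB"
high_mb = 12_000_000 * BYTES_PER_F32 / 1e6  # upper bound of observed range
print(f"hang onset: ~{low_mb:.0f}-{high_mb:.0f} MB")  # ~32-48 MB
```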
## Benchmarks: Same Hardware, Same NCCL, Different Results
| Method | AllReduce 4KB | AllReduce 933MB | Notes |
|---|---|---|---|
| torchrun (no GPU masking) | 1.5ms | 1.5ms | All sizes work perfectly |
| Ray (GPU masking forced) | 3,660ms | HANGS forever | 2400x slower, then hangs |
Both tests run on the same P5en.48xlarge nodes (8x H200, NVSwitch), same NCCL 2.27.5, same EFA networking, same container image. The ONLY difference is whether `CUDA_VISIBLE_DEVICES` is masked per-process.
On P5.48xlarge (H100, NVLink PXN, no NVSwitch), Ray with GPU masking works fine -- confirming this is specific to NVSwitch topology.
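For reference, the 2400x headline number follows directly from the table (first-operation latency, 4 KB message):

```python
torchrun_ms = 1.5       # first AllReduce via torchrun (no masking)
ray_masked_ms = 3660.0  # first AllReduce via Ray with per-actor masking
slowdown = ray_masked_ms / torchrun_ms
print(f"{slowdown:.0f}x slower")  # 2440x slower, reported as ~2400x
```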
## Proposed Fix
Make `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES` configurable instead of unconditionally removing it. For example:

```python
# Allow users to opt out of Ray GPU masking (needed for H200/NVSwitch)
if not os.environ.get("NEMO_RL_KEEP_RAY_NOSET_CVD"):
    worker_env_vars.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)
```

Alternatively, expose it as a configuration parameter in the worker group config. When `RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1` is preserved, each worker sees all 8 GPUs and must call `torch.cuda.set_device(local_rank)` explicitly -- which NeMo RL workers already do.
## Current Workaround
The only workaround is to patch the installed worker_groups.py in the container's venv to comment out the .pop() line, which is fragile and breaks on upgrades.
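Until a proper flag exists, the patch can be scripted. This is a sketch, not a supported procedure: the path lookup assumes `nemo_rl` is importable from the container venv's `python`, and the `sed` pattern assumes the `.pop()` line matches the source verbatim.

```shell
# Locate the installed module (path varies by container/venv layout).
FILE="${FILE:-$(python -c 'import nemo_rl.distributed.worker_groups as m; print(m.__file__)' 2>/dev/null || true)}"
# Comment out the .pop() line (keeps a .bak backup).
# Fragile: must be re-applied after every upgrade.
[ -n "$FILE" ] && sed -i.bak \
  's/^\( *\)\(worker_env_vars\.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)\)/\1# \2/' \
  "$FILE" || echo "set FILE to the path of worker_groups.py"
```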
## Environment
- Hardware: 4x P5en.48xlarge (8x H200 per node, NVSwitch)
- NCCL: 2.27.5 with aws-ofi-nccl 1.18.0
- NeMo RL: Latest main branch
- Ray: 2.44.1
- Network: EFA with Libfabric 2.4
## Related Issues
- ray-project/ray#61073 -- NCCL AllReduce 2400x slower via Ray actors vs torchrun on H200 (NVSwitch) multi-node (upstream Ray issue)
- #1961 -- Multi-node GRPO training fails on P5en.48xlarge (H200) due to Ray+NCCL issue (broader tracking issue)
- NVIDIA/nccl#1749 -- slow first-time op with `NCCL_CUMEM_ENABLE=1` and non-default `CUDA_VISIBLE_DEVICES` in Ray workers (cuMem import penalty, confirmed by NVIDIA)
- NVIDIA/nccl#1906 -- `all_gather` returns tensors in incorrect rank order for certain permutations of `CUDA_VISIBLE_DEVICES` and `NCCL_ALGO=NVLS` (regression in NCCL 2.26.2; NCCL 2.21.5 OK) (NVLS rank ordering, confirmed by NVIDIA)