Skip to content

Commit 790cc2f

Browse files
authored
[c10d] Add more tests to prevent extra context (pytorch#154179)
ghstack-source-id: da61972 Pull Request resolved: pytorch#154174
1 parent 62ea99a commit 790cc2f

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,56 @@ def test_extra_cuda_context(self):
671671
except ModuleNotFoundError:
672672
self._helper_test_extra_cuda_context_by_memory()
673673

674+
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_extra_cuda_context_sync_ops(self):
    """Run a series of synchronous collectives and verify that none of them
    spawns an extra CUDA context on a peer device.

    Detection is done via NVML: after the collectives complete, each rank
    counts the compute processes resident on its own GPU and asserts there
    is at most one (itself).
    """
    # NVML is required to enumerate processes on the device; skip when the
    # binding is unavailable or fails to initialize.
    try:
        import pynvml

        pynvml.nvmlInit()
    except Exception:
        self.skipTest("pynvml not available")

    # Set up the NCCL process group, binding each rank to its own device
    # so that collectives have no excuse to touch device 0.
    device = torch.device(f"cuda:{self.rank:d}")
    rendezvous_store = c10d.FileStore(self.file_name, self.world_size)
    c10d.init_process_group(
        backend="nccl",
        store=rendezvous_store,
        rank=self.rank,
        world_size=self.world_size,
        device_id=device,
    )

    one_elem = torch.empty((1,), device=device)
    world_buf = torch.empty((self.world_size,), device=device)

    # Exercise each sync collective once; any of these could, if buggy,
    # create a context on a remote device.
    c10d.all_reduce(one_elem)
    c10d.reduce(one_elem, dst=0)
    c10d.broadcast(one_elem, src=0)
    c10d.all_gather_into_tensor(world_buf, one_elem)
    c10d.reduce_scatter_tensor(one_elem, world_buf)
    c10d.barrier()

    # Give remote ranks a moment to (wrongly) touch this device before
    # rank 0 takes its NVML snapshot.
    if self.rank == 0:
        time.sleep(5)

    dev_handle = pynvml.nvmlDeviceGetHandleByIndex(self.rank)
    ctx_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(dev_handle)
    num_ctx_procs = len(ctx_procs)

    # Keep all ranks alive until rank 0 finishes the NVML check, then
    # tear the group down before asserting.
    c10d.barrier()
    c10d.destroy_process_group()
    self.assertLessEqual(
        num_ctx_procs,
        1,
        f"Found {num_ctx_procs} processes creating contexts on {device}, expecting 1 at most",
    )
723+
674724
@requires_nccl()
675725
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
676726
def test_destruct_before_terminate_pg(self):

0 commit comments

Comments
 (0)