Skip to content

Commit 1214198

Browse files
authored
[c10d] Fix extra CUDA context created by barrier (pytorch#152834)
Fixes pytorch#149119. In ProcessGroup.hpp, we create a dummy tensor for dispatching. This requires a correct device index. This PR uses `device_id` given by user when calling `init_process_group`. This PR also uses `torch._C._get_accelerator()` to determine the device type. ghstack-source-id: 96c32b9 Pull Request resolved: pytorch#149144
1 parent 790cc2f commit 1214198

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3555,17 +3555,6 @@ def test_nccl_barrier_device_ids(self):
35553555

35563556
c10d.barrier(device_ids=[self.rank])
35573557

3558-
@requires_nccl()
3559-
@skip_if_lt_x_gpu(2)
3560-
def test_nccl_barrier_device_ids_function_argument(self):
3561-
store = c10d.FileStore(self.file_name, self.world_size)
3562-
c10d.init_process_group(
3563-
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
3564-
)
3565-
3566-
with self.assertRaisesRegex(TypeError, "Invalid function argument"):
3567-
c10d.barrier(device_ids=self.rank)
3568-
35693558
@requires_nccl()
35703559
@skip_if_lt_x_gpu(2)
35713560
def test_unwaited(self) -> None:

torch/distributed/distributed_c10d.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4596,29 +4596,42 @@ def barrier(
45964596
group (ProcessGroup, optional): The process group to work on. If None,
45974597
the default process group will be used.
45984598
async_op (bool, optional): Whether this op should be an async op
4599-
device_ids ([int], optional): List of device/GPU ids.
4599+
device_ids ([int], optional): List of device/GPU ids. Only one id is expected.
46004600
46014601
Returns:
46024602
Async work handle, if async_op is set to True.
46034603
None, if not async_op or if not part of the group
46044604
46054605
.. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective.
46064606
"""
4607+
group = group or _get_default_group()
4608+
46074609
if _rank_not_in_group(group):
46084610
_warn_not_in_group("barrier")
46094611
return
46104612

46114613
opts = BarrierOptions()
4612-
opts.device = torch.device(_get_object_coll_device(group))
4613-
if device_ids is not None:
4614-
if isinstance(device_ids, list):
4615-
opts.device_ids = device_ids
4616-
else:
4617-
raise TypeError(
4618-
"Invalid function argument: device_ids type should be List[int]"
4619-
)
4614+
# Detect the accelerator on the machine. If no accelerator is available, it
4615+
# returns CPU.
4616+
device = torch._C._get_accelerator()
4617+
if isinstance(device_ids, list):
4618+
opts.device_ids = device_ids
4619+
# use only the first device id
4620+
opts.device = torch.device(device.type, device_ids[0])
4621+
elif getattr(group, "bound_device_id", None) is not None:
4622+
# Use device id from `init_process_group(device_id=...)`
4623+
opts.device = group.bound_device_id # type: ignore[assignment]
4624+
elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
4625+
opts.device = torch.device("cpu")
4626+
else:
4627+
# Use the current device set by the user. If user did not set any, this
4628+
# may use default device 0, causing issues like hang or all processes
4629+
# creating context on device 0.
4630+
opts.device = device
4631+
warnings.warn( # warn only once
4632+
"No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. "
4633+
)
46204634

4621-
group = group or _get_default_group()
46224635
work = group.barrier(opts=opts)
46234636

46244637
if async_op:

0 commit comments

Comments
 (0)