diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 9f55de94d9135..b4611e165f917 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -22,7 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Let `_get_default_process_group_backend_for_device` support more hardware platforms (
+  [#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))
 
 
 ### Fixed
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index e826b910c16d3..af182ad7f422f 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -160,7 +160,17 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
         else:
-            torch.distributed.barrier()
+            # Handle PyTorch bug where barrier() fails on CPU with "PrivateUse1HooksInterface" error
+            try:
+                torch.distributed.barrier()
+            except RuntimeError as e:
+                if "PrivateUse1HooksInterface" in str(e):
+                    # Fallback: use all_reduce as a barrier - all processes must participate,
+                    # which achieves the same synchronization effect as barrier()
+                    dummy_tensor = torch.tensor(0.0, device=self.root_device)
+                    torch.distributed.all_reduce(dummy_tensor)
+                else:
+                    raise
 
     @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index ec4eb261f2d3e..500f3a3e2aa92 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:
 
 
 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    return "nccl" if device.type == "cuda" else "gloo"
+    """Return the corresponding distributed backend for a given device."""
+    device_backend_map = torch.distributed.Backend.default_device_backend_map
+    if device.type in device_backend_map:
+        return device_backend_map[device.type]
+    return "gloo"
 
 
 class _DatasetSamplerWrapper(Dataset):
diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index d65eaa810ff4d..51c4b320d5525 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -17,6 +17,7 @@
 from lightning.fabric.utilities.distributed import (
     _destroy_dist_connection,
     _gather_all_tensors,
+    _get_default_process_group_backend_for_device,
     _InfiniteBarrier,
     _init_dist_connection,
     _is_dtensor,
@@ -243,6 +244,27 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.register.assert_not_called()
 
 
+def test_get_default_process_group_backend_for_device():
+    """Test that each device type maps to its correct default process group backend."""
+    # register a custom backend for the test
+    torch.utils.rename_privateuse1_backend("pcu")
+
+    def mock_backend(store, group_rank, group_size, timeout):
+        pass
+
+    torch.distributed.Backend.register_backend(
+        "pccl",
+        lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
+        devices=["pcu"],
+    )
+
+    # test that the default backend is correctly set for each device
+    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
+    backends = ["gloo", "nccl", "pccl"]
+    for device, backend in zip(devices, backends):
+        assert _get_default_process_group_backend_for_device(device) == backend
+
+
 @RunIf(min_torch="2.4")
 def test_is_dtensor(monkeypatch):
     from torch.distributed._tensor import DTensor