
Commit 5be1825

Merge branch 'master' into ioannis@18861-CSVLogger-fails-on-remote-fs
2 parents 509e639 + 3c81316

File tree

6 files changed: +42, -5 lines

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- let `_get_default_process_group_backend_for_device` support more hardware platforms (
+[#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))
 
 
 ### Fixed

src/lightning/fabric/strategies/ddp.py

Lines changed: 11 additions & 1 deletion
@@ -160,7 +160,17 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
         else:
-            torch.distributed.barrier()
+            # Handle PyTorch bug where barrier() fails on CPU with "PrivateUse1HooksInterface" error
+            try:
+                torch.distributed.barrier()
+            except RuntimeError as e:
+                if "PrivateUse1HooksInterface" in str(e):
+                    # Fallback: Use all_reduce as barrier - all processes must participate
+                    # This achieves the same synchronization effect as barrier()
+                    dummy_tensor = torch.tensor(0.0, device=self.root_device)
+                    torch.distributed.all_reduce(dummy_tensor)
+                else:
+                    raise
 
     @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
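For reference, the fallback used above can be sketched in isolation. The helper below is a minimal, hypothetical example (the name safe_cpu_barrier and the assumption of an already-initialized process group are not part of this commit); an all_reduce on a dummy tensor blocks every rank until all ranks have contributed, which gives the same synchronization as barrier().

    import torch
    import torch.distributed as dist

    def safe_cpu_barrier(device: torch.device) -> None:
        # Try the regular barrier first; fall back to an all_reduce-based barrier
        # only for the specific "PrivateUse1HooksInterface" failure mode.
        try:
            dist.barrier()
        except RuntimeError as err:
            if "PrivateUse1HooksInterface" in str(err):
                dummy = torch.tensor(0.0, device=device)  # stand-in for self.root_device in the strategy
                dist.all_reduce(dummy)  # returns only after every rank has participated
            else:
                raise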

src/lightning/fabric/utilities/distributed.py

Lines changed: 5 additions & 1 deletion
@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:
 
 
 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    return "nccl" if device.type == "cuda" else "gloo"
+    """Return corresponding distributed backend for a given device."""
+    device_backend_map = torch.distributed.Backend.default_device_backend_map
+    if device.type in device_backend_map:
+        return device_backend_map[device.type]
+    return "gloo"
 
 
 class _DatasetSamplerWrapper(Dataset):
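The new implementation relies on torch.distributed.Backend.default_device_backend_map, a dictionary mapping device types to their default process-group backends (typically including "cpu" -> "gloo" and "cuda" -> "nccl"; the exact contents depend on the PyTorch build and any registered third-party backends). A minimal sketch of the same lookup, with default_backend_for as a hypothetical name:

    import torch
    import torch.distributed as dist

    def default_backend_for(device: torch.device) -> str:
        # Look up the device type in PyTorch's default backend map and fall back
        # to "gloo" for device types without a registered default backend.
        return dist.Backend.default_device_backend_map.get(device.type, "gloo")

    print(default_backend_for(torch.device("cpu")))   # "gloo" on builds with gloo support
    print(default_backend_for(torch.device("cuda")))  # "nccl" on builds with NCCL support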

tests/tests_fabric/utilities/test_distributed.py

Lines changed: 22 additions & 0 deletions
@@ -17,6 +17,7 @@
 from lightning.fabric.utilities.distributed import (
     _destroy_dist_connection,
     _gather_all_tensors,
+    _get_default_process_group_backend_for_device,
     _InfiniteBarrier,
     _init_dist_connection,
     _is_dtensor,
@@ -243,6 +244,27 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.register.assert_not_called()
 
 
+def test_get_default_process_group_backend_for_device():
+    """Test that each device type maps to its correct default process group backend."""
+    # register a custom backend for test
+    torch.utils.rename_privateuse1_backend("pcu")
+
+    def mock_backend(store, group_rank, group_size, timeout):
+        pass
+
+    torch.distributed.Backend.register_backend(
+        "pccl",
+        lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
+        devices=["pcu"],
+    )
+
+    # test that the default backend is correctly set for each device
+    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
+    backends = ["gloo", "nccl", "pccl"]
+    for device, backend in zip(devices, backends):
+        assert _get_default_process_group_backend_for_device(device) == backend
+
+
 @RunIf(min_torch="2.4")
 def test_is_dtensor(monkeypatch):
     from torch.distributed._tensor import DTensor
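The test works because torch.utils.rename_privateuse1_backend exposes the PrivateUse1 placeholder device under a custom name, and torch.distributed.Backend.register_backend(..., devices=[...]) records the new backend as the default for those device types, which is what default_device_backend_map then reports. A standalone sketch of that mechanism, reusing the test's illustrative names "pcu" and "pccl" with a no-op constructor as a stand-in:

    import torch
    import torch.distributed as dist

    # Make torch.device("pcu") a recognized device type (renames the PrivateUse1 placeholder).
    torch.utils.rename_privateuse1_backend("pcu")

    # Register a dummy process-group constructor and declare it the default for "pcu" devices.
    dist.Backend.register_backend(
        "pccl",
        lambda store, group_rank, group_size, timeout: None,
        devices=["pcu"],
    )

    print(dist.Backend.default_device_backend_map.get("pcu"))  # expected: "pccl"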

tests/tests_fabric/utilities/test_spike.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
     )
 
 
-@pytest.mark.flaky(max_runs=3)
+@pytest.mark.flaky(reruns=3)
 @pytest.mark.parametrize(
     ("global_rank_spike", "num_devices", "spike_value", "finite_only"),
     # NOTE FOR ALL FOLLOWING TESTS:

tests/tests_pytorch/callbacks/test_spike.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
 
 
-@pytest.mark.flaky(max_runs=3)
+@pytest.mark.flaky(reruns=3)
 @pytest.mark.parametrize(
     ("global_rank_spike", "num_devices", "spike_value", "finite_only"),
     # NOTE FOR ALL FOLLOWING TESTS:
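Both spike tests change the marker argument from max_runs to reruns: reruns is the keyword used by the pytest-rerunfailures plugin, whereas max_runs is the separate flaky package's spelling. A minimal usage sketch with an illustrative test name:

    import pytest

    @pytest.mark.flaky(reruns=3)  # pytest-rerunfailures retries a failing test up to 3 times
    def test_occasionally_flaky():
        assert True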

0 commit comments