Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/lightning/fabric/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,10 @@ def _destroy_dist_connection() -> None:


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
return "nccl" if device.type == "cuda" else "gloo"
device_backend_map = torch.distributed.Backend.default_device_backend_map
if device.type in device_backend_map:
return device_backend_map[device.type]
return "gloo"


class _DatasetSamplerWrapper(Dataset):
Expand Down
21 changes: 21 additions & 0 deletions tests/tests_fabric/utilities/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from lightning.fabric.utilities.distributed import (
_destroy_dist_connection,
_gather_all_tensors,
_get_default_process_group_backend_for_device,
_InfiniteBarrier,
_init_dist_connection,
_is_dtensor,
Expand Down Expand Up @@ -243,6 +244,26 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
atexit_mock.register.assert_not_called()


def test_get_default_process_group_backend_for_device():
    """The default backend is resolved from the device type, including custom backends."""
    # register a custom device type + backend for the test
    torch.utils.rename_privateuse1_backend("pcu")

    def mock_backend(store, group_rank, group_size, timeout):
        pass

    torch.distributed.Backend.register_backend(
        "pccl",
        # pass the function directly — wrapping it in an argument-forwarding lambda adds nothing
        mock_backend,
        devices=["pcu"],
    )

    # test that the default backend is correctly set for each device
    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
    backends = ["gloo", "nccl", "pccl"]
    for device, backend in zip(devices, backends):
        assert _get_default_process_group_backend_for_device(device) == backend


@RunIf(min_torch="2.4")
def test_is_dtensor(monkeypatch):
from torch.distributed._tensor import DTensor
Expand Down
Loading