diff --git a/src/lightning/fabric/plugins/collectives/torch_collective.py b/src/lightning/fabric/plugins/collectives/torch_collective.py
index 883380bb881aa..182e75f4583ef 100644
--- a/src/lightning/fabric/plugins/collectives/torch_collective.py
+++ b/src/lightning/fabric/plugins/collectives/torch_collective.py
@@ -24,6 +24,8 @@ class TorchCollective(Collective):
     """

     manages_default_group = False
+    addr_key = "MASTER_ADDR"
+    port_key = "MASTER_PORT"

     def __init__(self) -> None:
         if not dist.is_available():
@@ -136,26 +138,21 @@ def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = N
         if self.is_initialized():
             return self
         # maybe set addr
-        set_addr = False
-        addr_key = "MASTER_ADDR"
-        if main_address is not None and addr_key not in os.environ:
-            os.environ[addr_key] = main_address
-            set_addr = True
+        setting_env = []
+        if main_address is not None and self.addr_key not in os.environ:
+            os.environ[self.addr_key] = main_address
+            setting_env.append(self.addr_key)
         # maybe set port
-        set_port = False
-        port_key = "MASTER_PORT"
-        if main_port is not None and port_key not in os.environ:
-            os.environ[port_key] = str(main_port)
-            set_port = True
+        if main_port is not None and self.port_key not in os.environ:
+            os.environ[self.port_key] = str(main_port)
+            setting_env.append(self.port_key)
         # this will `init_group`
         super().setup(**kwargs)
         # set as a class attribute so any instance can know whether we initialized the default process group
         TorchCollective.manages_default_group = True
         # cleanup
-        if set_addr:
-            os.environ.pop("MASTER_ADDR", None)
-        if set_port:
-            os.environ.pop("MASTER_PORT", None)
+        for kenv in setting_env:
+            os.environ.pop(kenv, None)
         return self

     @override
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index 500f3a3e2aa92..ec4eb261f2d3e 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -319,11 +319,7 @@ def _destroy_dist_connection() -> None:


 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    """Return corresponding distributed backend for a given device."""
-    device_backend_map = torch.distributed.Backend.default_device_backend_map
-    if device.type in device_backend_map:
-        return device_backend_map[device.type]
-    return "gloo"
+    return "nccl" if device.type == "cuda" else "gloo"


 class _DatasetSamplerWrapper(Dataset):
diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py
index 2abfe73c92dec..2eaf1d23572c8 100644
--- a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py
+++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py
@@ -46,7 +46,8 @@ def test_memory_sharing_disabled(strategy):

 def _test_memory_sharing_disabled(fabric, tensor, model):
     is_spawn = fabric.strategy.launcher._start_method == "spawn"
-    assert not is_spawn or tensor.is_shared()
+    if is_spawn:
+        assert tensor.is_shared()
     assert not model.layer.weight.is_shared()
     assert not model.tied_layer.weight.is_shared()
     assert not model.buffer.is_shared()
diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index 51c4b320d5525..d65eaa810ff4d 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -17,7 +17,6 @@
 from lightning.fabric.utilities.distributed import (
     _destroy_dist_connection,
     _gather_all_tensors,
-    _get_default_process_group_backend_for_device,
     _InfiniteBarrier,
     _init_dist_connection,
     _is_dtensor,
@@ -244,27 +243,6 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.register.assert_not_called()


-def test_get_default_process_group_backend_for_device():
-    """Test that each device type maps to its correct default process group backend."""
-    # register a custom backend for test
-    torch.utils.rename_privateuse1_backend("pcu")
-
-    def mock_backend(store, group_rank, group_size, timeout):
-        pass
-
-    torch.distributed.Backend.register_backend(
-        "pccl",
-        lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
-        devices=["pcu"],
-    )
-
-    # test that the default backend is correctly set for each device
-    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
-    backends = ["gloo", "nccl", "pccl"]
-    for device, backend in zip(devices, backends):
-        assert _get_default_process_group_backend_for_device(device) == backend
-
-
 @RunIf(min_torch="2.4")
 def test_is_dtensor(monkeypatch):
     from torch.distributed._tensor import DTensor
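
For context, the reworked `TorchCollective.setup` hunk above follows a set-only-if-absent / pop-only-what-you-set pattern for the rendezvous environment variables. The following is a minimal standalone sketch of that pattern, not Lightning code: the `_rendezvous_env` helper and its `init_fn` callback are hypothetical names for illustration, and the try/finally is an extra safeguard added in this sketch (the diff itself pops the keys after `super().setup(**kwargs)` returns).

import os
from typing import Callable, Optional

ADDR_KEY = "MASTER_ADDR"  # same keys the diff promotes to class attributes
PORT_KEY = "MASTER_PORT"


def _rendezvous_env(main_address: Optional[str], main_port: Optional[int], init_fn: Callable[[], None]) -> None:
    """Hypothetical helper: export MASTER_ADDR/MASTER_PORT only if unset, then remove exactly those keys."""
    set_keys = []  # track only the keys this call exported, so pre-existing values are never touched
    if main_address is not None and ADDR_KEY not in os.environ:
        os.environ[ADDR_KEY] = main_address
        set_keys.append(ADDR_KEY)
    if main_port is not None and PORT_KEY not in os.environ:
        os.environ[PORT_KEY] = str(main_port)
        set_keys.append(PORT_KEY)
    try:
        init_fn()  # e.g. a call that initializes the default process group
    finally:
        # cleanup: remove only what this call added
        for key in set_keys:
            os.environ.pop(key, None)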