Skip to content

Commit 9ca360b

Browse files
debug failing tests for Fabric with ddp_fork on PT 2.8 -> revert #21057 (#21092)
* debug failing tests for Fabric with `ddp_fork` on PT 2.8.
* Revert "let `_get_default_process_group_backend_for_device` support more hardware platforms (#21057)". This reverts commit 119a640.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5ac872e commit 9ca360b

File tree

4 files changed

+14
-42
lines changed

4 files changed

+14
-42
lines changed

src/lightning/fabric/plugins/collectives/torch_collective.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ class TorchCollective(Collective):
2424
"""
2525

2626
manages_default_group = False
27+
addr_key = "MASTER_ADDR"
28+
port_key = "MASTER_PORT"
2729

2830
def __init__(self) -> None:
2931
if not dist.is_available():
@@ -136,26 +138,21 @@ def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = N
136138
if self.is_initialized():
137139
return self
138140
# maybe set addr
139-
set_addr = False
140-
addr_key = "MASTER_ADDR"
141-
if main_address is not None and addr_key not in os.environ:
142-
os.environ[addr_key] = main_address
143-
set_addr = True
141+
setting_env = []
142+
if main_address is not None and self.addr_key not in os.environ:
143+
os.environ[self.addr_key] = main_address
144+
setting_env.append(self.addr_key)
144145
# maybe set port
145-
set_port = False
146-
port_key = "MASTER_PORT"
147-
if main_port is not None and port_key not in os.environ:
148-
os.environ[port_key] = str(main_port)
149-
set_port = True
146+
if main_port is not None and self.port_key not in os.environ:
147+
os.environ[self.port_key] = str(main_port)
148+
setting_env.append(self.port_key)
150149
# this will `init_group`
151150
super().setup(**kwargs)
152151
# set as a class attribute so any instance can know whether we initialized the default process group
153152
TorchCollective.manages_default_group = True
154153
# cleanup
155-
if set_addr:
156-
os.environ.pop("MASTER_ADDR", None)
157-
if set_port:
158-
os.environ.pop("MASTER_PORT", None)
154+
for kenv in setting_env:
155+
os.environ.pop(kenv, None)
159156
return self
160157

161158
@override

src/lightning/fabric/utilities/distributed.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,7 @@ def _destroy_dist_connection() -> None:
319319

320320

321321
def _get_default_process_group_backend_for_device(device: torch.device) -> str:
322-
"""Return corresponding distributed backend for a given device."""
323-
device_backend_map = torch.distributed.Backend.default_device_backend_map
324-
if device.type in device_backend_map:
325-
return device_backend_map[device.type]
326-
return "gloo"
322+
return "nccl" if device.type == "cuda" else "gloo"
327323

328324

329325
class _DatasetSamplerWrapper(Dataset):

tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def test_memory_sharing_disabled(strategy):
4646

4747
def _test_memory_sharing_disabled(fabric, tensor, model):
4848
is_spawn = fabric.strategy.launcher._start_method == "spawn"
49-
assert not is_spawn or tensor.is_shared()
49+
if is_spawn:
50+
assert tensor.is_shared()
5051
assert not model.layer.weight.is_shared()
5152
assert not model.tied_layer.weight.is_shared()
5253
assert not model.buffer.is_shared()

tests/tests_fabric/utilities/test_distributed.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from lightning.fabric.utilities.distributed import (
1818
_destroy_dist_connection,
1919
_gather_all_tensors,
20-
_get_default_process_group_backend_for_device,
2120
_InfiniteBarrier,
2221
_init_dist_connection,
2322
_is_dtensor,
@@ -244,27 +243,6 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
244243
atexit_mock.register.assert_not_called()
245244

246245

247-
def test_get_default_process_group_backend_for_device():
248-
"""Test that each device type maps to its correct default process group backend."""
249-
# register a custom backend for test
250-
torch.utils.rename_privateuse1_backend("pcu")
251-
252-
def mock_backend(store, group_rank, group_size, timeout):
253-
pass
254-
255-
torch.distributed.Backend.register_backend(
256-
"pccl",
257-
lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
258-
devices=["pcu"],
259-
)
260-
261-
# test that the default backend is correctly set for each device
262-
devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
263-
backends = ["gloo", "nccl", "pccl"]
264-
for device, backend in zip(devices, backends):
265-
assert _get_default_process_group_backend_for_device(device) == backend
266-
267-
268246
@RunIf(min_torch="2.4")
269247
def test_is_dtensor(monkeypatch):
270248
from torch.distributed._tensor import DTensor

0 commit comments

Comments (0)