
Commit db5a7db

carmocca and awaelchli authored and committed
Restore support for builds without distributed (#18859)
Co-authored-by: Adrian Wälchli <[email protected]> (cherry picked from commit 78ad390)
1 parent c71970c commit db5a7db
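
The fix boils down to never calling `torch.distributed.is_initialized()` without first checking `torch.distributed.is_available()`. On PyTorch builds compiled without distributed support (the default for macOS builds from source, or any build with `USE_DISTRIBUTED=0`), `is_initialized` is not defined at all, so the unguarded call raises `AttributeError`. A minimal sketch of the guarded pattern this commit applies throughout (illustrative, not the exact Lightning code):

    import torch.distributed

    def distributed_is_initialized() -> bool:
        # `is_available()` must be checked first and short-circuit: on builds
        # without distributed support, `is_initialized` does not even exist.
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    if distributed_is_initialized():
        torch.distributed.barrier()  # only reached when a process group exists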

File tree

13 files changed: +35 −19 lines


src/lightning/fabric/plugins/collectives/torch_collective.py

Lines changed: 1 addition & 1 deletion
@@ -167,7 +167,7 @@ def is_available(cls) -> bool:
 
     @classmethod
     def is_initialized(cls) -> bool:
-        return dist.is_initialized()
+        return cls.is_available() and dist.is_initialized()
 
     @classmethod
     def init_group(cls, **kwargs: Any) -> None:
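
With this change, `TorchCollective.is_initialized()` short-circuits on `is_available()` instead of touching the conditionally defined `dist.is_initialized`. A sketch of the expected behavior on a hypothetical build without distributed support:

    from lightning.fabric.plugins.collectives import TorchCollective

    # On a PyTorch build with USE_DISTRIBUTED=0, is_available() is False, so
    # is_initialized() now returns False instead of raising AttributeError.
    if not TorchCollective.is_available():
        assert TorchCollective.is_initialized() is False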

src/lightning/fabric/strategies/ddp.py

Lines changed: 3 additions & 2 deletions
@@ -33,6 +33,7 @@
 from lightning.fabric.strategies.strategy import TBroadcast, _BackwardSyncControl
 from lightning.fabric.utilities.distributed import (
     ReduceOp,
+    _distributed_is_initialized,
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
     _sync_ddp_if_available,
@@ -143,15 +144,15 @@ def all_reduce(
         return tensor
 
     def barrier(self, *args: Any, **kwargs: Any) -> None:
-        if not torch.distributed.is_initialized():
+        if not _distributed_is_initialized():
             return
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
         else:
             torch.distributed.barrier()
 
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not torch.distributed.is_initialized():
+        if not _distributed_is_initialized():
             return obj
 
         obj = [obj]
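
The guarded checks make Fabric's DDP collectives degenerate to no-ops when no process group exists, which is what a single-process run on a distributed-free build needs. A usage sketch, assuming `strategy` is an already set-up Fabric `DDPStrategy` instance:

    # With no initialized process group, both calls return early:
    strategy.barrier()             # no-op instead of raising
    obj = strategy.broadcast(obj)  # returns `obj` unchanged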

src/lightning/fabric/strategies/fsdp.py

Lines changed: 3 additions & 2 deletions
@@ -54,6 +54,7 @@
 )
 from lightning.fabric.utilities.distributed import (
     ReduceOp,
+    _distributed_is_initialized,
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
     _sync_ddp_if_available,
@@ -355,15 +356,15 @@ def all_reduce(
         return tensor
 
     def barrier(self, *args: Any, **kwargs: Any) -> None:
-        if not torch.distributed.is_initialized():
+        if not _distributed_is_initialized():
             return
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=[self.root_device.index])
         else:
             torch.distributed.barrier()
 
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not torch.distributed.is_initialized():
+        if not _distributed_is_initialized():
             return obj
 
         obj = [obj]

src/lightning/fabric/utilities/distributed.py

Lines changed: 9 additions & 2 deletions
@@ -167,7 +167,7 @@ def _sync_ddp_if_available(
         reduced value
 
     """
-    if torch.distributed.is_initialized():
+    if _distributed_is_initialized():
        return _sync_ddp(result, group=group, reduce_op=reduce_op)
     return result
 
@@ -244,7 +244,7 @@ def _all_gather_ddp_if_available(
         A tensor of shape (world_size, batch, ...)
 
     """
-    if not torch.distributed.is_initialized():
+    if not _distributed_is_initialized():
         return tensor
 
     from torch.distributed.nn.functional import all_gather
@@ -373,3 +373,10 @@ def _set_num_threads_if_needed(num_processes: int = 1) -> None:
     num_threads = _suggested_max_num_threads(num_processes)
     torch.set_num_threads(num_threads)
     os.environ["OMP_NUM_THREADS"] = str(num_threads)
+
+
+def _distributed_is_initialized() -> bool:
+    # `is_initialized` is only defined conditionally
+    # https://github.com/pytorch/pytorch/blob/v2.1.0/torch/distributed/__init__.py#L25
+    # this might happen to MacOS builds from source (default) or any build from source that sets `USE_DISTRIBUTED=0`
+    return torch.distributed.is_available() and torch.distributed.is_initialized()
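
The new `_distributed_is_initialized` helper is what every call site in this commit switches to. It is private API, so the snippet below is only an illustration of how the same guard can be used in user code, for example in test teardown:

    import torch.distributed
    from lightning.fabric.utilities.distributed import _distributed_is_initialized

    def cleanup() -> None:
        # Safe on any build, including those compiled with USE_DISTRIBUTED=0.
        if _distributed_is_initialized():
            torch.distributed.destroy_process_group()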

src/lightning/pytorch/loops/utilities.py

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@
 from torch import Tensor
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities.distributed import _distributed_is_initialized
 from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_1_13
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch.accelerators.xla import XLAAccelerator
@@ -160,7 +161,7 @@ def _decorator(self: _Loop, *args: Any, **kwargs: Any) -> Any:
         if not hasattr(self, "inference_mode"):
             raise TypeError(f"`{type(self).__name__}.inference_mode` needs to be defined")
         context_manager: Type[ContextManager]
-        if dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo":  # noqa: SIM114
+        if _distributed_is_initialized() and dist.get_backend() == "gloo":  # noqa: SIM114
             # gloo backend does not work properly.
             # https://github.com/Lightning-AI/lightning/pull/12715/files#r854569110
             # TODO: explore why and possibly open an issue in PyTorch repository
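
For context, the branch above picks the evaluation context manager; a simplified sketch of that selection (an assumption-heavy reduction, not the exact Lightning decorator):

    import torch
    import torch.distributed as dist
    from lightning.fabric.utilities.distributed import _distributed_is_initialized

    def evaluation_context(use_inference_mode: bool):
        # Fall back to no_grad on the gloo backend, where inference_mode is
        # known to misbehave (see the linked PR discussion above).
        if _distributed_is_initialized() and dist.get_backend() == "gloo":
            return torch.no_grad()
        return torch.inference_mode() if use_inference_mode else torch.no_grad()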

src/lightning/pytorch/strategies/ddp.py

Lines changed: 3 additions & 2 deletions
@@ -28,6 +28,7 @@
 from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.utilities.distributed import (
+    _distributed_is_initialized,
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
     _sync_ddp_if_available,
@@ -282,7 +283,7 @@ def determine_ddp_device_ids(self) -> Optional[List[int]]:
         return [self.root_device.index]
 
     def barrier(self, *args: Any, **kwargs: Any) -> None:
-        if not torch.distributed.is_initialized():
+        if not _distributed_is_initialized():
             return
 
         if torch.distributed.get_backend() == "nccl":
@@ -291,7 +292,7 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
             torch.distributed.barrier()
 
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not torch.distributed.is_initialized():
+        if not _distributed_is_initialized():
             return obj
 
         obj = [obj]

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 3 additions & 2 deletions
@@ -43,6 +43,7 @@
     _setup_activation_checkpointing,
 )
 from lightning.fabric.utilities.distributed import (
+    _distributed_is_initialized,
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
     _sync_ddp_if_available,
@@ -382,15 +383,15 @@ def model_sharded_context(self) -> Generator[None, None, None]:
         yield
 
     def barrier(self, name: Optional[str] = None) -> None:
-        if not torch.distributed.is_initialized():
+        if not _distributed_is_initialized():
             return
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self._determine_device_ids())
         else:
             torch.distributed.barrier()
 
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        if not torch.distributed.is_initialized():
+        if not _distributed_is_initialized():
             return obj
 
         obj = [obj]

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
 
 from lightning.fabric.utilities import move_data_to_device
 from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
+from lightning.fabric.utilities.distributed import _distributed_is_initialized
 from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch.utilities.data import extract_batch_size
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -425,7 +426,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
         elif not on_step and result_metric.meta.on_epoch:
             if result_metric._computed is None:
                 should = result_metric.meta.sync.should
-                if not should and result_metric.is_tensor and torch.distributed.is_initialized():
+                if not should and result_metric.is_tensor and _distributed_is_initialized():
                     warning_cache.warn(
                         f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
                         " when logging on epoch level in distributed setting to accumulate the metric across"

tests/tests_fabric/conftest.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 import lightning.fabric
 import pytest
 import torch.distributed
+from lightning.fabric.utilities.distributed import _distributed_is_initialized
 
 
 @pytest.fixture(autouse=True)
@@ -71,7 +72,7 @@ def restore_env_variables():
 def teardown_process_group():
     """Ensures that the distributed process group gets closed before the next test runs."""
     yield
-    if torch.distributed.is_available() and torch.distributed.is_initialized():
+    if _distributed_is_initialized():
         torch.distributed.destroy_process_group()
 
 

tests/tests_fabric/plugins/collectives/test_single_device.py

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@ def test_can_instantiate_without_args():
 
 def test_create_group():
     collective = SingleDeviceCollective()
-    assert collective.is_available()
     assert collective.is_initialized()
 
     with pytest.raises(RuntimeError, match=r"SingleDeviceCollective` does not own a group"):

0 commit comments
