
Commit c3614f1

Fix: skip importing DistributedOptimizer for Windows (#10071)

Parent: 454e93b

File tree: 5 files changed (+12, -5 lines)

docs/source/advanced/mixed_precision.rst
1 addition, 1 deletion

@@ -50,7 +50,7 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
 Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.
 
 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
+    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()
 
     Trainer(gpus=1, precision="bf16")
 

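The `:skipif:` option comes from `sphinx.ext.doctest`: its argument is a Python expression, and the snippet is skipped when the expression evaluates to true. For names like `torch` and `_TORCH_GREATER_EQUAL_DEV_1_10` to be resolvable inside that expression, they are typically provided through `doctest_global_setup` in the Sphinx config. A minimal sketch, assuming names and placement; Lightning's actual `conf.py` may differ:

    # Hypothetical excerpt from docs/source/conf.py -- illustrative only.
    extensions = ["sphinx.ext.doctest"]

    # Code run before each doctest group; names defined here are also visible
    # to :skipif: expressions such as the one patched above.
    doctest_global_setup = """
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
    """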
pytorch_lightning/core/lightning.py
2 additions, 1 deletion

@@ -38,6 +38,7 @@
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import (
+    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_DEV_1_10,
     GradClipAlgorithmType,
     rank_zero_deprecation,
@@ -2041,7 +2042,7 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
 
         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
         """
-        if not _TORCH_GREATER_EQUAL_DEV_1_10:
+        if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
             return
 
         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

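The fix follows a guard-then-import pattern: the version and platform checks run before `torch.distributed._sharded_tensor` is ever referenced, because on Windows that import itself fails even on torch 1.10-dev builds. A minimal sketch of the pattern with placeholder flags (the real method goes on to register both hooks on the module, which is omitted here):

    import platform

    _IS_WINDOWS = platform.system() == "Windows"   # assumed definition of the flag
    _TORCH_GREATER_EQUAL_DEV_1_10 = False          # placeholder for the real version check

    def _register_sharded_tensor_state_dict_hooks_if_available(module) -> None:
        # Return before the import: on old torch the submodule does not exist,
        # and on Windows importing it raises.
        if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
            return

        # Safe to import only once both gates have passed.
        from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

        # ... hook registration on `module` happens here in the real method.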
pytorch_lightning/plugins/training_type/ddp.py
6 additions, 2 deletions

@@ -42,6 +42,7 @@
 from pytorch_lightning.utilities import (
     _FAIRSCALE_AVAILABLE,
     _HYDRA_AVAILABLE,
+    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
@@ -57,7 +58,9 @@
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_10:
-    from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+    if not _IS_WINDOWS:
+        from torch.distributed.optim import DistributedOptimizer
+    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -333,8 +336,9 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer
 
+        is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
         if (
-            isinstance(optimizer, DistributedOptimizer)
+            is_distributed_optimizer
             or isinstance(optimizer, ZeroRedundancyOptimizer)
             or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
         ):

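Because `DistributedOptimizer` is now only bound on non-Windows platforms, an unconditional `isinstance(optimizer, DistributedOptimizer)` would raise `NameError` on Windows. The commit therefore evaluates the check lazily behind the same flag. A self-contained sketch of that pattern, with a hypothetical `PlatformOnlyOptimizer` standing in for `DistributedOptimizer`:

    import platform

    _IS_WINDOWS = platform.system() == "Windows"

    if not _IS_WINDOWS:
        class PlatformOnlyOptimizer:  # stand-in for torch.distributed.optim.DistributedOptimizer
            pass

    def _is_platform_only_optimizer(optimizer) -> bool:
        # The conditional expression short-circuits on Windows, so the unbound
        # class name is never evaluated there and no NameError can occur.
        return isinstance(optimizer, PlatformOnlyOptimizer) if not _IS_WINDOWS else False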
pytorch_lightning/utilities/__init__.py
1 addition, 0 deletions

@@ -38,6 +38,7 @@
     _HYDRA_EXPERIMENTAL_AVAILABLE,
     _IPU_AVAILABLE,
     _IS_INTERACTIVE,
+    _IS_WINDOWS,
     _JSONARGPARSE_AVAILABLE,
     _module_available,
     _OMEGACONF_AVAILABLE,

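`_IS_WINDOWS` is re-exported here so call sites can import it alongside the other availability flags. A minimal sketch, assuming the flag is a simple platform check (the exact definition lives elsewhere in the utilities package and may differ):

    import platform

    # Assumed definition: true when running under native Windows.
    _IS_WINDOWS = platform.system() == "Windows"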
tests/core/test_lightning_module.py
2 additions, 1 deletion

@@ -21,7 +21,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
+from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -315,6 +315,7 @@ def __init__(self, spec):
 @pytest.mark.skipif(
     not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Test requires the torch version to support `ShardedTensor`"
 )
+@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
 def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
     spec = dist._sharding_spec.ChunkShardingSpec(
         dim=0,

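Stacking `@pytest.mark.skipif` decorators composes as an OR: the test is skipped if any condition is truthy, and pytest reports the matching reason. A sketch of the same shape, with placeholder flags standing in for Lightning's real version check:

    import platform

    import pytest

    _IS_WINDOWS = platform.system() == "Windows"
    _TORCH_GREATER_EQUAL_DEV_1_10 = False  # placeholder for the real version check

    @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="needs torch with ShardedTensor support")
    @pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
    def test_sharded_tensor_state_dict():
        ...  # body runs only when both gates pass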