Commit e6b061a

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 00727e8 commit e6b061a

2 files changed, 20 additions and 10 deletions

examples/pytorch/domain_templates/generative_adversarial_net_ddp.py

Lines changed: 7 additions & 6 deletions
@@ -18,26 +18,27 @@
     tensorboard --logdir default

 """
+
 import math
+
+# ! TESTING
+import os
+import sys
 from argparse import ArgumentParser, Namespace

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

-# ! TESTING
-import os
-import sys
-
-sys.path.append(os.path.join(os.getcwd(), "src"))  # noqa: E402
+sys.path.append(os.path.join(os.getcwd(), "src"))
 # ! TESTING

 from lightning.pytorch import cli_lightning_logo
 from lightning.pytorch.core import LightningModule
 from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
+from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
 from lightning.pytorch.trainer import Trainer
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
-from lightning.pytorch.strategies.ddp import DDPStrategy, MultiModelDDPStrategy

 if _TORCHVISION_AVAILABLE:
     import torchvision
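
Note: the hunk above sorts the imports, drops the unused DDPStrategy import, and keeps only MultiModelDDPStrategy. The commit does not show how the example wires the strategy into the script, so the following is only a minimal sketch under the assumption that it is passed to Trainer like any other built-in strategy; the no-argument constructor is an illustrative guess.

# Sketch (assumption): MultiModelDDPStrategy plugs into Trainer via the `strategy`
# argument, the same way DDPStrategy does. Its constructor arguments are not shown
# in this commit and are omitted here.
from lightning.pytorch import Trainer
from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy

trainer = Trainer(accelerator="gpu", devices=2, strategy=MultiModelDDPStrategy())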

src/lightning/pytorch/strategies/ddp.py

Lines changed: 13 additions & 4 deletions
@@ -107,7 +107,9 @@ def __init__(
     @property
     def is_distributed(self) -> bool:  # pragma: no-cover
         """Legacy property kept for backwards compatibility."""
-        rank_zero_deprecation(f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6)
+        rank_zero_deprecation(
+            f"`{type(self).__name__}.is_distributed` is deprecated. Use is discouraged.", stacklevel=6
+        )
         return True

     @property
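
This hunk only rewraps the deprecation call to satisfy the line-length limit; behaviour is unchanged. For context, the observable effect is roughly the following (a sketch assuming a plain DDPStrategy instance):

# Sketch: reading the legacy property returns True and emits the deprecation
# warning shown above (on rank zero only).
from lightning.pytorch.strategies import DDPStrategy

strategy = DDPStrategy()
assert strategy.is_distributed is True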
@@ -227,7 +229,9 @@ def _register_ddp_hooks(self) -> None:
     def _enable_model_averaging(self) -> None:
         log.debug(f"{self.__class__.__name__}: reinitializing optimizers with post localSGD")
         if self._model_averaging_period is None:
-            raise ValueError("Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy.")
+            raise ValueError(
+                "Post-localSGD algorithm is used, but model averaging period is not provided to DDP strategy."
+            )
         from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer

         for optimizer in self.optimizers:
@@ -236,7 +240,10 @@ def _enable_model_averaging(self) -> None:

             is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
             if isinstance(optimizer, (ZeroRedundancyOptimizer, PostLocalSGDOptimizer)) or is_distributed_optimizer:
-                raise ValueError(f"Currently model averaging cannot work with a distributed optimizer of type " f"{optimizer.__class__.__name__}.")
+                raise ValueError(
+                    f"Currently model averaging cannot work with a distributed optimizer of type "
+                    f"{optimizer.__class__.__name__}."
+                )

         assert self._ddp_comm_state is not None
         self._model_averager = torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager(
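
Both _enable_model_averaging hunks are pure reformatting of error messages. This code path is only reached when DDP is configured for post-localSGD, in which case model_averaging_period must be set or the first ValueError above is raised. A hedged configuration sketch based on the public DDPStrategy arguments (not taken from this commit):

# Sketch: post-localSGD setup; leaving model_averaging_period unset (None) would
# trigger the "model averaging period is not provided" error reformatted above.
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy

strategy = DDPStrategy(
    ddp_comm_state=post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=8),
    ddp_comm_hook=post_localSGD.post_localSGD_hook,
    model_averaging_period=4,
)
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)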
@@ -316,7 +323,9 @@ def model_to_device(self) -> None:
         self.model.to(self.root_device)

     @override
-    def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor:
+    def reduce(
+        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
+    ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.

         Args:
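
The reduce hunk likewise only rewraps the signature. As described in its docstring, reduce aggregates a per-process tensor across all DDP ranks, averaging by default (reduce_op="mean"). A minimal usage sketch follows; the class name, metric name, and loss helper are illustrative, not part of this commit.

# Sketch: averaging a per-rank scalar across processes through the trainer's strategy.
from lightning.pytorch import LightningModule


class ExampleModule(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        # reduce_op="mean" is the default; every rank receives the averaged value
        mean_loss = self.trainer.strategy.reduce(loss, reduce_op="mean")
        self.log("val_loss", mean_loss)
        return loss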
