Commit 4b22284

misc: merge gan training example, add docstring of MultiModelDDPStrategy
1 parent 8b1fe23

File tree

2 files changed: +14 −247 lines


examples/pytorch/domain_templates/generative_adversarial_net.py

Lines changed: 14 additions & 3 deletions
@@ -29,6 +29,7 @@
 from lightning.pytorch import cli_lightning_logo
 from lightning.pytorch.core import LightningModule
 from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
+from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
 from lightning.pytorch.trainer import Trainer
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

@@ -209,10 +210,18 @@ def main(args: Namespace) -> None:
     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
-    # If use distributed training PyTorch recommends to use DistributedDataParallel.
-    # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
     dm = MNISTDataModule()
-    trainer = Trainer(accelerator="gpu", devices=1)
+
+    if args.use_ddp:
+        # `MultiModelDDPStrategy` is critical for multi-GPU GAN training.
+        # With the existing `DDPStrategy` there are only two workarounds:
+        #   1) enable the `find_unused_parameters` option, or
+        #   2) replace `self.manual_backward(loss)` with `loss.backward()`.
+        # Neither is desirable.
+        trainer = Trainer(accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
+    else:
+        # A single GPU needs no special strategy; use the default.
+        trainer = Trainer(accelerator="gpu", devices=1)

     # ------------------------
     # 3 START TRAINING

@@ -229,6 +238,8 @@ def main(args: Namespace) -> None:
     parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
     parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
     parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
+    parser.add_argument("--use_ddp", action="store_true", help="use DistributedDataParallel across all available GPUs")
+
     args = parser.parse_args()

     main(args)
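
For context on the two workarounds named in the hunk above: the GAN template runs with Lightning's manual optimization, alternating a generator backward and a discriminator backward inside one training_step. A minimal sketch of that pattern follows; the linear layers and loss wiring are illustrative stand-ins, not the template's actual architecture.

import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch.core import LightningModule

class SketchGAN(LightningModule):
    def __init__(self, latent_dim: int = 100):
        super().__init__()
        self.automatic_optimization = False  # we step the two optimizers ourselves
        self.latent_dim = latent_dim
        # Stand-in networks; the real template uses deeper stacks.
        self.generator = nn.Sequential(nn.Linear(latent_dim, 784), nn.Tanh())
        self.discriminator = nn.Linear(784, 1)

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return opt_g, opt_d

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        imgs = imgs.view(imgs.size(0), -1)
        opt_g, opt_d = self.optimizers()
        z = torch.randn(imgs.size(0), self.latent_dim, device=self.device)
        ones = torch.ones(imgs.size(0), 1, device=self.device)
        zeros = torch.zeros(imgs.size(0), 1, device=self.device)

        # Generator step: the loss flows through D into G, but only opt_g updates.
        g_loss = F.binary_cross_entropy_with_logits(
            self.discriminator(self.generator(z)), ones
        )
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # Discriminator step: the generator output is detached, so the
        # generator's parameters receive no gradient in this backward.
        d_loss = F.binary_cross_entropy_with_logits(self.discriminator(imgs), ones)
        d_loss = d_loss + F.binary_cross_entropy_with_logits(
            self.discriminator(self.generator(z).detach()), zeros
        )
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

Because some parameters go without gradients in a given backward pass, a single DDP wrapper around the whole LightningModule reports them as unused and errors out unless `find_unused_parameters=True` is set; wrapping the generator and discriminator separately, as the new strategy's name suggests, sidesteps both workarounds.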

examples/pytorch/domain_templates/generative_adversarial_net_ddp.py

Lines changed: 0 additions & 244 deletions
This file was deleted.
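
After the merge, both modes run from the single remaining script. An invocation sketch, assuming the standard argparse entry point shown in the diff above:

# single GPU, default strategy
python examples/pytorch/domain_templates/generative_adversarial_net.py

# all available GPUs with MultiModelDDPStrategy
python examples/pytorch/domain_templates/generative_adversarial_net.py --use_ddp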
