 from lightning.pytorch import cli_lightning_logo
 from lightning.pytorch.core import LightningModule
 from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
+from lightning.pytorch.strategies.ddp import MultiModelDDPStrategy
 from lightning.pytorch.trainer import Trainer
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 
@@ -209,10 +210,18 @@ def main(args: Namespace) -> None:
     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
-    # If use distributed training PyTorch recommends to use DistributedDataParallel.
-    # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
     dm = MNISTDataModule()
-    trainer = Trainer(accelerator="gpu", devices=1)
+
+    if args.use_ddp:
+        # `MultiModelDDPStrategy` is critical for multi-GPU GAN training.
+        # There are only two ways to make this example run with the existing `DDPStrategy`:
+        #   1) enable the `find_unused_parameters` option, or
+        #   2) replace `self.manual_backward(loss)` with `loss.backward()`.
+        # Neither workaround is desirable.
+        trainer = Trainer(accelerator="gpu", devices=-1, strategy=MultiModelDDPStrategy())
+    else:
+        # A single GPU works with the default strategy.
+        trainer = Trainer(accelerator="gpu", devices=1)
 
     # ------------------------
     # 3 START TRAINING
@@ -229,6 +238,8 @@ def main(args: Namespace) -> None:
     parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
     parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
     parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
+    parser.add_argument("--use_ddp", action="store_true", help="train with DDP (MultiModelDDPStrategy) on all available GPUs")
+
     args = parser.parse_args()
 
     main(args)
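
For context on the comment in the INIT TRAINER hunk above: the example trains two separate models (a generator and a discriminator) with manual optimization, roughly as in the sketch below. This is an illustrative reconstruction rather than code from this diff; the `SketchGAN` name, its constructor arguments, and the exact loss wiring are assumptions, while `self.optimizers()`, `toggle_optimizer`, `manual_backward`, and `untoggle_optimizer` are standard `LightningModule` APIs. Because each step backpropagates through only one of the two parameter sets, a single DDP wrapper around the whole module hits the unused-parameter problem the added comment describes, which is what `MultiModelDDPStrategy` is meant to avoid.

# Simplified sketch, assuming a `generator` that maps latent vectors to images and a
# `discriminator` that outputs probabilities; names are illustrative, not the diff's code.
import torch
import torch.nn.functional as F
from lightning.pytorch.core import LightningModule


class SketchGAN(LightningModule):
    def __init__(self, generator, discriminator, latent_dim=100, lr=0.0002):
        super().__init__()
        self.automatic_optimization = False  # two models, two optimizers -> manual optimization
        self.generator, self.discriminator = generator, discriminator
        self.latent_dim, self.lr = latent_dim, lr

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        opt_g, opt_d = self.optimizers()
        z = torch.randn(imgs.size(0), self.latent_dim, device=self.device)
        real = torch.ones(imgs.size(0), 1, device=self.device)
        fake = torch.zeros(imgs.size(0), 1, device=self.device)

        # Generator step: only the generator's parameters receive gradients here, so a
        # plain DDP wrapper around the whole module sees the discriminator's as "unused".
        self.toggle_optimizer(opt_g)
        g_loss = F.binary_cross_entropy(self.discriminator(self.generator(z)), real)
        self.manual_backward(g_loss)
        opt_g.step()
        opt_g.zero_grad()
        self.untoggle_optimizer(opt_g)

        # Discriminator step: conversely, the generator's parameters are untouched here.
        self.toggle_optimizer(opt_d)
        d_loss = 0.5 * (
            F.binary_cross_entropy(self.discriminator(imgs), real)
            + F.binary_cross_entropy(self.discriminator(self.generator(z).detach()), fake)
        )
        self.manual_backward(d_loss)
        opt_d.step()
        opt_d.zero_grad()
        self.untoggle_optimizer(opt_d)

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.lr)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr)
        return opt_g, opt_d

With the new flag, a multi-GPU run would be launched by passing `--use_ddp` to the example script, while omitting the flag keeps the original single-GPU path.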