Skip to content

Commit 8b1fe23

Browse files
committed
misc: resolve some review comments for product consistency
1 parent ece7d38 commit 8b1fe23

File tree

2 files changed

+17
-36
lines changed

2 files changed

+17
-36
lines changed

examples/pytorch/domain_templates/generative_adversarial_net_ddp.py

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,12 @@
2020
"""
2121

2222
import math
23-
24-
# ! TESTING
25-
import os
26-
import sys
2723
from argparse import ArgumentParser, Namespace
2824

2925
import torch
3026
import torch.nn as nn
3127
import torch.nn.functional as F
3228

33-
sys.path.append(os.path.join(os.getcwd(), "src"))
34-
# ! TESTING
35-
3629
from lightning.pytorch import cli_lightning_logo
3730
from lightning.pytorch.core import LightningModule
3831
from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule
@@ -44,7 +37,7 @@
4437
import torchvision
4538

4639

47-
def _block(in_feat: int, out_feat: int, normalize: bool = True):
40+
def _block(in_feat: int, out_feat: int, normalize: bool = True) -> list:
4841
layers = [nn.Linear(in_feat, out_feat)]
4942
if normalize:
5043
layers.append(nn.BatchNorm1d(out_feat, 0.8))
@@ -135,10 +128,6 @@ def __init__(
135128

136129
self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
137130

138-
# ! TESTING
139-
self.save_path = "pl_test_multi_gpu"
140-
os.makedirs(self.save_path, exist_ok=True)
141-
142131
def forward(self, z):
143132
return self.generator(z)
144133

@@ -203,36 +192,25 @@ def configure_optimizers(self):
203192
opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
204193
return opt_g, opt_d
205194

206-
# ! TESTING
207-
def on_train_epoch_start(self):
208-
if self.trainer.is_global_zero:
209-
print("GEN: ", self.generator.module.model[0].bias[:10])
210-
print("DISC: ", self.discriminator.module.model[0].bias[:10])
211-
212-
# ! TESTING
213-
def validation_step(self, batch, batch_idx):
214-
pass
195+
def on_train_epoch_end(self):
196+
z = self.validation_z.type_as(self.generator.model[0].weight)
215197

216-
# ! TESTING
217-
@torch.no_grad()
218-
def on_validation_epoch_end(self):
219-
if not self.current_epoch % 5:
220-
return
221-
self.generator.eval(), self.discriminator.eval()
222-
223-
z = self.validation_z.type_as(self.generator.module.model[0].weight)
198+
# log sampled images
224199
sample_imgs = self(z)
225-
226-
if self.trainer.is_global_zero:
227-
grid = torchvision.utils.make_grid(sample_imgs)
228-
torchvision.utils.save_image(grid, os.path.join(self.save_path, f"epoch_{self.current_epoch}.png"))
229-
230-
self.generator.train(), self.discriminator.train()
200+
grid = torchvision.utils.make_grid(sample_imgs)
201+
for logger in self.loggers:
202+
logger.experiment.add_image("generated_images", grid, self.current_epoch)
231203

232204

233205
def main(args: Namespace) -> None:
206+
# ------------------------
207+
# 1 INIT LIGHTNING MODEL
208+
# ------------------------
234209
model = GAN(lr=args.lr, b1=args.b1, b2=args.b2, latent_dim=args.latent_dim)
235210

211+
# ------------------------
212+
# 2 INIT TRAINER
213+
# ------------------------
236214
# ! `MultiModelDDPStrategy` is critical for multi-gpu training
237215
# ! Otherwise, it will not work with multiple models.
238216
# ! There are two ways to run training code with the previous `DDPStrategy`;
@@ -246,6 +224,9 @@ def main(args: Namespace) -> None:
246224
max_epochs=100,
247225
)
248226

227+
# ------------------------
228+
# 3 START TRAINING
229+
# ------------------------
249230
trainer.fit(model, dm)
250231

251232

src/lightning/pytorch/strategies/ddp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def teardown(self) -> None:
421421

422422
class MultiModelDDPStrategy(DDPStrategy):
423423
@override
424-
def _setup_model(self, model: Module) -> Module:
424+
def _setup_model(self, model: Module) -> DistributedDataParallel:
425425
device_ids = self.determine_ddp_device_ids()
426426
log.debug(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
427427
# https://pytorch.org/docs/stable/notes/cuda.html#id5

0 commit comments

Comments
 (0)