2020"""
2121
2222import math
23-
24- # ! TESTING
25- import os
26- import sys
2723from argparse import ArgumentParser , Namespace
2824
2925import torch
3026import torch .nn as nn
3127import torch .nn .functional as F
3228
33- sys .path .append (os .path .join (os .getcwd (), "src" ))
34- # ! TESTING
35-
3629from lightning .pytorch import cli_lightning_logo
3730from lightning .pytorch .core import LightningModule
3831from lightning .pytorch .demos .mnist_datamodule import MNISTDataModule
4437 import torchvision
4538
4639
47- def _block (in_feat : int , out_feat : int , normalize : bool = True ):
40+ def _block (in_feat : int , out_feat : int , normalize : bool = True ) -> list :
4841 layers = [nn .Linear (in_feat , out_feat )]
4942 if normalize :
5043 layers .append (nn .BatchNorm1d (out_feat , 0.8 ))
@@ -135,10 +128,6 @@ def __init__(
135128
136129 self .example_input_array = torch .zeros (2 , self .hparams .latent_dim )
137130
138- # ! TESTING
139- self .save_path = "pl_test_multi_gpu"
140- os .makedirs (self .save_path , exist_ok = True )
141-
142131 def forward (self , z ):
143132 return self .generator (z )
144133
@@ -203,36 +192,25 @@ def configure_optimizers(self):
203192 opt_d = torch .optim .Adam (self .discriminator .parameters (), lr = lr , betas = (b1 , b2 ))
204193 return opt_g , opt_d
205194
206- # ! TESTING
207- def on_train_epoch_start (self ):
208- if self .trainer .is_global_zero :
209- print ("GEN: " , self .generator .module .model [0 ].bias [:10 ])
210- print ("DISC: " , self .discriminator .module .model [0 ].bias [:10 ])
211-
212- # ! TESTING
213- def validation_step (self , batch , batch_idx ):
214- pass
195+ def on_train_epoch_end (self ):
196+ z = self .validation_z .type_as (self .generator .model [0 ].weight )
215197
216- # ! TESTING
217- @torch .no_grad ()
218- def on_validation_epoch_end (self ):
219- if not self .current_epoch % 5 :
220- return
221- self .generator .eval (), self .discriminator .eval ()
222-
223- z = self .validation_z .type_as (self .generator .module .model [0 ].weight )
198+ # log sampled images`
224199 sample_imgs = self (z )
225-
226- if self .trainer .is_global_zero :
227- grid = torchvision .utils .make_grid (sample_imgs )
228- torchvision .utils .save_image (grid , os .path .join (self .save_path , f"epoch_{ self .current_epoch } .png" ))
229-
230- self .generator .train (), self .discriminator .train ()
200+ grid = torchvision .utils .make_grid (sample_imgs )
201+ for logger in self .loggers :
202+ logger .experiment .add_image ("generated_images" , grid , self .current_epoch )
231203
232204
233205def main (args : Namespace ) -> None :
206+ # ------------------------
207+ # 1 INIT LIGHTNING MODEL
208+ # ------------------------
234209 model = GAN (lr = args .lr , b1 = args .b1 , b2 = args .b2 , latent_dim = args .latent_dim )
235210
211+ # ------------------------
212+ # 2 INIT TRAINER
213+ # ------------------------
236214 # ! `MultiModelDDPStrategy` is critical for multi-gpu training
237215 # ! Otherwise, it will not work with multiple models.
238216 # ! There are two ways to run training codes with previous `DDPStrategy`;
@@ -246,6 +224,9 @@ def main(args: Namespace) -> None:
246224 max_epochs = 100 ,
247225 )
248226
227+ # ------------------------
228+ # 3 START TRAINING
229+ # ------------------------
249230 trainer .fit (model , dm )
250231
251232
0 commit comments