33from torch import optim
44from models import BaseVAE
55from models .types_ import *
6+ from utils import data_loader
67import pytorch_lightning as pl
78from torchvision import transforms
89import torchvision .utils as vutils
910from torchvision .datasets import CelebA
1011from torch .utils .data import DataLoader
1112
1213
13-
1414class VAEXperiment (pl .LightningModule ):
1515
1616 def __init__ (self ,
@@ -50,7 +50,7 @@ def validation_step(self, batch, batch_idx, optimizer_idx = 0):
5050
5151 results = self .forward (real_img , labels = labels )
5252 val_loss = self .model .loss_function (* results ,
53- M_N = self .params ['batch_size' ]/ self .num_train_imgs ,
53+ M_N = self .params ['batch_size' ]/ self .num_val_imgs ,
5454 optimizer_idx = optimizer_idx ,
5555 batch_idx = batch_idx )
5656
@@ -132,7 +132,7 @@ def configure_optimizers(self):
132132 except :
133133 return optims
134134
135- @pl . data_loader
135+ @data_loader
136136 def train_dataloader (self ):
137137 transform = self .data_transforms ()
138138
@@ -150,7 +150,7 @@ def train_dataloader(self):
150150 shuffle = True ,
151151 drop_last = True )
152152
153- @pl . data_loader
153+ @data_loader
154154 def val_dataloader (self ):
155155 transform = self .data_transforms ()
156156
@@ -162,8 +162,10 @@ def val_dataloader(self):
162162 batch_size = 144 ,
163163 shuffle = True ,
164164 drop_last = True )
165+ self .num_val_imgs = len (self .sample_dataloader )
165166 else :
166167 raise ValueError ('Undefined dataset type' )
168+
167169 return self .sample_dataloader
168170
169171 def data_transforms (self ):
0 commit comments