Skip to content

Commit 90172be

Browse files
committed
Fix for data_loader in pl version 0.7
1 parent 6d6a3f3 commit 90172be

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

experiment.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from torch import optim
44
from models import BaseVAE
55
from models.types_ import *
6+
from utils import data_loader
67
import pytorch_lightning as pl
78
from torchvision import transforms
89
import torchvision.utils as vutils
910
from torchvision.datasets import CelebA
1011
from torch.utils.data import DataLoader
1112

1213

13-
1414
class VAEXperiment(pl.LightningModule):
1515

1616
def __init__(self,
@@ -50,7 +50,7 @@ def validation_step(self, batch, batch_idx, optimizer_idx = 0):
5050

5151
results = self.forward(real_img, labels = labels)
5252
val_loss = self.model.loss_function(*results,
53-
M_N = self.params['batch_size']/ self.num_train_imgs,
53+
M_N = self.params['batch_size']/ self.num_val_imgs,
5454
optimizer_idx = optimizer_idx,
5555
batch_idx = batch_idx)
5656

@@ -132,7 +132,7 @@ def configure_optimizers(self):
132132
except:
133133
return optims
134134

135-
@pl.data_loader
135+
@data_loader
136136
def train_dataloader(self):
137137
transform = self.data_transforms()
138138

@@ -150,7 +150,7 @@ def train_dataloader(self):
150150
shuffle = True,
151151
drop_last=True)
152152

153-
@pl.data_loader
153+
@data_loader
154154
def val_dataloader(self):
155155
transform = self.data_transforms()
156156

@@ -162,8 +162,10 @@ def val_dataloader(self):
162162
batch_size= 144,
163163
shuffle = True,
164164
drop_last=True)
165+
self.num_val_imgs = len(self.sample_dataloader)
165166
else:
166167
raise ValueError('Undefined dataset type')
168+
167169
return self.sample_dataloader
168170

169171
def data_transforms(self):

utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytorch_lightning as pl
2+
3+
4+
## Utils to handle newer PyTorch Lightning changes from version 0.6
5+
## ==================================================================================================== ##
6+
7+
8+
def data_loader(fn):
9+
"""
10+
Decorator to handle the deprecation of data_loader from 0.7
11+
:param fn: User defined data loader function
12+
:return: A wrapper for the data_loader function
13+
"""
14+
15+
def func_wrapper(self):
16+
try: # Works for version 0.6.0
17+
return pl.data_loader(fn)(self)
18+
19+
except: # Works for version > 0.6.0
20+
return fn(self)
21+
22+
return func_wrapper

0 commit comments

Comments
 (0)