Commit 511f7ec

sholalkere authored and williamFalcon committed
Support for multiple val_dataloaders (#97)
* Added support for multiple validation dataloaders
* Fix typo in README.md
* Update trainer.py
* Add support for multiple dataloaders
* Rename dataloader_index to dataloader_i
* Added warning to check val_dataloaders  Added a warning to ensure that all val_dataloaders were DistributedSamplers if ddp is enabled
* Updated DistributedSampler warning
* Fixed typo
* Added multiple val_dataloaders
* Multiple val_dataloader test
* Update lightning_module_template.py  Added dataloader_i to validation_step parameters
* Update trainer.py
* Reverted template changes
* Create multi_val_module.py
* Update no_val_end_module.py
* New MultiValModel
* Rename MultiValModel to MultiValTestModel
* Revert to LightningTestModel
* Update test_models.py
* Update trainer.py
* Update test_models.py
* multiple val_dataloaders in test template
* Fixed flake8 warnings
* Update trainer.py
* Fix flake errors
* Fixed Flake8 errors
* Update lm_test_module.py  keep this test model with a single dataset for val
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update test_models.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update RequiredTrainerInterface.md
* Update RequiredTrainerInterface.md
* Update test_models.py
* Update trainer.py  dont need the else clause, val_dataloader is either a list or none because of get_dataloaders()
* Update trainer.py  fixed flake errors
* Update trainer.py
1 parent 46e27e3 commit 511f7ec
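
In short, after this commit a LightningModule's val_dataloader hook may return a list of dataloaders, and validation_step receives an extra dataloader_i argument identifying which loader produced the batch. Below is a minimal sketch of how the pieces fit together, assuming the CoolModel skeleton from the README diff further down (its __init__, forward, training_step and configure_optimizers are unchanged); the MNIST splits and batch sizes are illustrative choices, not part of the commit:

``` {.python}
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


class CoolModel(pl.LightningModule):
    # ... __init__, forward, training_step, configure_optimizers as in the README ...

    def validation_step(self, data_batch, batch_nb, dataloader_i):
        # dataloader_i tells you which of the val dataloaders produced this batch
        x, y = data_batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # collate the per-batch outputs into a single progress-bar metric
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    @pl.data_loader
    def val_dataloader(self):
        # returning a list runs one validation pass per loader
        mnist = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
        return [DataLoader(mnist, batch_size=32), DataLoader(mnist, batch_size=64)]
```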

File tree

6 files changed (+197, -108 lines)


docs/LightningModule/RequiredTrainerInterface.md

Lines changed: 92 additions & 70 deletions
@@ -46,24 +46,25 @@ class CoolModel(pl.LightningModule):
     def forward(self, x):
         return torch.relu(self.l1(x.view(x.size(0), -1)))

-    def my_loss(self, y_hat, y):
-        return F.cross_entropy(y_hat, y)
-
     def training_step(self, batch, batch_nb):
+        # REQUIRED
         x, y = batch
         y_hat = self.forward(x)
-        return {'loss': self.my_loss(y_hat, y)}
+        return {'loss': F.cross_entropy(y_hat, y)}

     def validation_step(self, batch, batch_nb):
+        # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'val_loss': self.my_loss(y_hat, y)}
+        return {'val_loss': F.cross_entropy(y_hat, y)}

     def validation_end(self, outputs):
+        # OPTIONAL
         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
         return {'avg_val_loss': avg_loss}

     def configure_optimizers(self):
+        # REQUIRED
         return [torch.optim.Adam(self.parameters(), lr=0.02)]

     @pl.data_loader
@@ -72,10 +73,13 @@ class CoolModel(pl.LightningModule):

     @pl.data_loader
     def val_dataloader(self):
+        # OPTIONAL
+        # can also return a list of val dataloaders
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

     @pl.data_loader
     def test_dataloader(self):
+        # OPTIONAL
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
 ```
 ---
@@ -88,7 +92,7 @@ The LightningModule interface is on the right. Each method corresponds to a part
 </a>
 </p>

----
+## Required Methods

 ### training_step

@@ -134,15 +138,75 @@ def training_step(self, data_batch, batch_nb):
     return output
 ```

----
+---
+### tng_dataloader
+
+``` {.python}
+@pl.data_loader
+def tng_dataloader(self)
+```
+Called by lightning during training loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed.
+
+##### Return
+PyTorch DataLoader
+
+**Example**
+
+``` {.python}
+@pl.data_loader
+def tng_dataloader(self):
+    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
+    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
+    loader = torch.utils.data.DataLoader(
+        dataset=dataset,
+        batch_size=self.hparams.batch_size,
+        shuffle=True
+    )
+    return loader
+```
+
+---
+### configure_optimizers
+
+``` {.python}
+def configure_optimizers(self)
+```
+
+Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple.
+Lightning will call .backward() and .step() on each one in every epoch. If you use 16 bit precision it will also handle that.
+
+
+##### Return
+List or Tuple - List of optimizers with an optional second list of learning-rate schedulers
+
+**Example**
+
+``` {.python}
+# most cases
+def configure_optimizers(self):
+    opt = Adam(self.parameters(), lr=0.01)
+    return [opt]
+
+# gan example, with scheduler for discriminator
+def configure_optimizers(self):
+    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
+    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
+    discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
+    return [generator_opt, discriminator_opt], [discriminator_sched]
+```
+
+## Optional Methods

 ### validation_step

 ``` {.python}
-def validation_step(self, data_batch, batch_nb)
+def validation_step(self, data_batch, batch_nb, dataloader_i)
 ```
+**OPTIONAL**
+If you don't need to validate you don't need to implement this method.
+
+In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes, calculate accuracy, or save example outputs (using self.experiment or whatever you want). Really, anything you want.

-In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something specific to your model.
 This is most likely the same as your training_step. But unlike training step, the outputs from here will go to validation_end for collation.

 **Params**
@@ -151,6 +215,7 @@ This is most likely the same as your training_step. But unlike training step, th
 |---|---|
 | data_batch | The output of your dataloader. A tensor, tuple or list |
 | batch_nb | Integer displaying which batch this is |
+| dataloader_i | Integer displaying which dataloader this is |

 **Return**

@@ -188,9 +253,12 @@ def validation_step(self, data_batch, batch_nb):

 ``` {.python}
 def validation_end(self, outputs)
-```
+```
+If you didn't define a validation_step, this won't be called.
+
+Called at the end of the validation loop with the output of each validation_step. Called once per validation dataset.

-Called at the end of the validation loop with the output of each validation_step.
+The outputs here are strictly for the progress bar. If you don't need to display anything, don't return anything.

 **Params**

@@ -225,36 +293,6 @@ def validation_end(self, outputs):
     return tqdm_dic
 ```

----
-### configure_optimizers
-
-``` {.python}
-def configure_optimizers(self)
-```
-
-Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple.
-Lightning will call .backward() and .step() on each one in every epoch. If you use 16 bit precision it will also handle that.
-
-
-##### Return
-List or Tuple - List of optimizers with an optional second list of learning-rate schedulers
-
-**Example**
-
-``` {.python}
-# most cases
-def configure_optimizers(self):
-    opt = Adam(self.parameters(), lr=0.01)
-    return [opt]
-
-# gan example, with scheduler for discriminator
-def configure_optimizers(self):
-    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
-    disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
-    discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
-    return [generator_opt, disriminator_opt], [discriminator_sched]
-```
-
 ---
 ### on_save_checkpoint

@@ -297,44 +335,20 @@ def on_load_checkpoint(self, checkpoint):
     self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
 ```

----
-### tng_dataloader
-
-``` {.python}
-@pl.data_loader
-def tng_dataloader(self)
-```
-Called by lightning during training loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed.
-
-##### Return
-PyTorch DataLoader
-
-**Example**
-
-``` {.python}
-@pl.data_loader
-def tng_dataloader(self):
-    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
-    dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
-    loader = torch.utils.data.DataLoader(
-        dataset=dataset,
-        batch_size=self.hparams.batch_size,
-        shuffle=True
-    )
-    return loader
-```
-
 ---
 ### val_dataloader

 ``` {.python}
 @pl.data_loader
 def tng_dataloader(self)
 ```
-Called by lightning during validation loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed.
+**OPTIONAL**
+If you don't need a validation dataset and a validation_step, you don't need to implement this method.
+
+Called by lightning during validation loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed.

 ##### Return
-PyTorch DataLoader
+PyTorch DataLoader or list of PyTorch Dataloaders.

 **Example**

@@ -350,6 +364,11 @@ def val_dataloader(self):
     )

     return loader
+
+# can also return multiple dataloaders
+@pl.data_loader
+def val_dataloader(self):
+    return [loader_a, loader_b, ..., loader_n]
 ```

 ---
@@ -359,6 +378,9 @@ def val_dataloader(self):
 @pl.data_loader
 def test_dataloader(self)
 ```
+**OPTIONAL**
+If you don't need a test dataset and a test_step, you don't need to implement this method.
+
 Called by lightning during test loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed.

 ##### Return
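
The commit message also mentions a new warning: when ddp is enabled, every val dataloader should be wrapped in a DistributedSampler. A hedged sketch of what that looks like with the list-returning val_dataloader documented above; the toy datasets and names dataset_a / dataset_b are placeholders, not names from this commit:

``` {.python}
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

import pytorch_lightning as pl

# toy stand-ins for real validation sets (placeholders)
dataset_a = TensorDataset(torch.randn(64, 10), torch.zeros(64, dtype=torch.long))
dataset_b = TensorDataset(torch.randn(64, 10), torch.zeros(64, dtype=torch.long))


# inside your LightningModule
@pl.data_loader
def val_dataloader(self):
    # under ddp, each val dataloader carries its own DistributedSampler
    return [
        DataLoader(dataset_a, batch_size=32, sampler=DistributedSampler(dataset_a)),
        DataLoader(dataset_b, batch_size=32, sampler=DistributedSampler(dataset_b)),
    ]
```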

examples/new_project_templates/lightning_module_template.py

Lines changed: 2 additions & 2 deletions
@@ -105,7 +105,7 @@ def training_step(self, data_batch, batch_i):
         # can also return just a scalar instead of a dict (return loss_val)
         return output

-    def validation_step(self, data_batch, batch_i):
+    def validation_step(self, data_batch, batch_i, dataloader_i):
         """
         Lightning calls this inside the validation loop
         :param data_batch:
@@ -218,7 +218,7 @@ def tng_dataloader(self):
     @pl.data_loader
     def val_dataloader(self):
         print('val data loader called')
-        return self.__dataloader(train=False)
+        return [self.__dataloader(train=False) for i in range(2)]

     @pl.data_loader
     def test_dataloader(self):
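
Since the template now feeds two validation dataloaders, a user module can use dataloader_i to keep the metrics from each loader apart. A rough sketch; the key naming scheme is illustrative, not something the template prescribes:

``` {.python}
from torch.nn import functional as F


# inside the template's LightningModule subclass
def validation_step(self, data_batch, batch_i, dataloader_i):
    x, y = data_batch
    y_hat = self.forward(x)
    loss_val = F.cross_entropy(y_hat, y)
    # tag the loss with the dataloader index so the two val sets stay distinguishable
    return {'val_loss_%d' % dataloader_i: loss_val}
```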
