
Commit e5805bf

val and test are optional now (#95)
* made validation step optional
* added no val model
* val_step can be implemented but not validation_end
* added no val end model
* added tests (x2)
* remove class (x11)
* updated docs (x2)
* updated test (x9)
* fix pep8
1 parent 996b1f9 commit e5805bf

File tree

8 files changed: +609 additions, -53 deletions


README.md

Lines changed: 8 additions & 4 deletions
@@ -81,36 +81,40 @@ class CoolModel(pl.LightningModule):
     def forward(self, x):
         return torch.relu(self.l1(x.view(x.size(0), -1)))
 
-    def my_loss(self, y_hat, y):
-        return F.cross_entropy(y_hat, y)
-
     def training_step(self, batch, batch_nb):
+        # REQUIRED
         x, y = batch
         y_hat = self.forward(x)
-        return {'loss': self.my_loss(y_hat, y)}
+        return {'loss': F.cross_entropy(y_hat, y)}
 
     def validation_step(self, batch, batch_nb):
+        # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
         return {'val_loss': self.my_loss(y_hat, y)}
 
     def validation_end(self, outputs):
+        # OPTIONAL
         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
         return {'avg_val_loss': avg_loss}
 
     def configure_optimizers(self):
+        # REQUIRED
         return [torch.optim.Adam(self.parameters(), lr=0.02)]
 
     @pl.data_loader
     def tng_dataloader(self):
+        # REQUIRED
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
 
     @pl.data_loader
     def val_dataloader(self):
+        # OPTIONAL
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
 
     @pl.data_loader
     def test_dataloader(self):
+        # OPTIONAL
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
 ```
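To make the REQUIRED/OPTIONAL split above concrete: after this commit, a model that implements only the required hooks trains as-is. The sketch below is illustrative and not part of the commit; it simply strips the README's CoolModel down to its required methods (the class name MinimalModel is made up here).

```python
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


class MinimalModel(pl.LightningModule):
    """Implements only the REQUIRED hooks; all validation/test hooks are omitted."""

    def __init__(self):
        super(MinimalModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        return {'loss': F.cross_entropy(self.forward(x), y)}

    def configure_optimizers(self):
        # REQUIRED
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @pl.data_loader
    def tng_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()), batch_size=32)
```

A Trainer can then fit such a model; instead of hitting the NotImplementedError the base class used to raise for validation_step, the trainer now simply skips the validation loop (see the trainer.py changes below).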

docs/LightningModule/RequiredTrainerInterface.md

Lines changed: 5 additions & 7 deletions
@@ -10,16 +10,14 @@ Otherwise, to Define a Lightning Module, implement the following methods:
 **Required**:
 
 - [training_step](RequiredTrainerInterface.md#training_step)
-- [validation_step](RequiredTrainerInterface.md#validation_step)
-- [validation_end](RequiredTrainerInterface.md#validation_end)
-
-- [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers)
-
-- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader)
 - [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader)
-- [test_dataloader](RequiredTrainerInterface.md#test_dataloader)
+- [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers)
 
 **Optional**:
+- [validation_step](RequiredTrainerInterface.md#validation_step)
+- [validation_end](RequiredTrainerInterface.md#validation_end)
+- [val_dataloader](RequiredTrainerInterface.md#val_dataloader)
+- [test_dataloader](RequiredTrainerInterface.md#test_dataloader)
 
 - [on_save_checkpoint](RequiredTrainerInterface.md#on_save_checkpoint)
 - [on_load_checkpoint](RequiredTrainerInterface.md#on_load_checkpoint)

pytorch_lightning/models/trainer.py

Lines changed: 23 additions & 7 deletions
@@ -13,6 +13,7 @@
 import torch.multiprocessing as mp
 import torch.distributed as dist
 
+from pytorch_lightning.root_module.root_module import LightningModule
 from pytorch_lightning.root_module.memory import get_gpu_memory_map
 from pytorch_lightning.root_module.model_saving import TrainerIO
 from pytorch_lightning.pt_overrides.override_data_parallel import (
@@ -312,6 +313,14 @@ def __is_function_implemented(self, f_name):
         f_op = getattr(model, f_name, None)
         return callable(f_op)
 
+    def __is_overriden(self, f_name):
+        model = self.__get_model()
+        super_object = super(model.__class__, model)
+
+        # when code pointers are different, it was overriden
+        is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__
+        return is_overriden
+
     @property
     def __tng_tqdm_dic(self):
         tqdm_dic = {
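The new __is_overriden helper is the core of the feature: because the base LightningModule now provides no-op defaults (see root_module.py below), a plain hasattr/callable check like __is_function_implemented would report every hook as present, so the trainer instead compares code objects to see whether the subclass really overrode the method. A minimal standalone sketch of the same idea, using made-up Base/WithVal/WithoutVal classes:

```python
class Base:
    def validation_step(self, batch, batch_nb):
        pass  # no-op default, like the new LightningModule


class WithVal(Base):
    def validation_step(self, batch, batch_nb):
        return {'val_loss': 0.0}


class WithoutVal(Base):
    pass  # inherits the no-op default


def is_overriden(obj, f_name):
    # the hook counts as overridden when the method resolved on the instance
    # has a different code object than the one inherited from the base class
    base = super(obj.__class__, obj)
    return getattr(obj, f_name).__code__ is not getattr(base, f_name).__code__


print(is_overriden(WithVal(), 'validation_step'))     # True
print(is_overriden(WithoutVal(), 'validation_step'))  # False
```

The trainer uses this check twice below: to skip the whole validation loop when validation_step is not overridden, and to decide whether validation_end should be called on the collected outputs.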
@@ -345,13 +354,13 @@ def __layout_bookeeping(self):
         self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)
 
         # determine number of validation batches
-        self.nb_val_batches = len(self.val_dataloader)
+        self.nb_val_batches = len(self.val_dataloader) if self.val_dataloader is not None else 0
         self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
         self.nb_val_batches = max(1, self.nb_val_batches)
         self.nb_val_batches = self.nb_val_batches
 
         # determine number of test batches
-        self.nb_test_batches = len(self.test_dataloader)
+        self.nb_test_batches = len(self.test_dataloader) if self.test_dataloader is not None else 0
         self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check)
 
         # determine when to check validation
@@ -372,6 +381,10 @@ def validate(self, model, dataloader, max_batches):
         :param max_batches: Scalar
         :return:
         """
+        # skip validation if model has no validation_step defined
+        if not self.__is_overriden('validation_step'):
+            return {}
+
         # enable eval mode
         model.zero_grad()
         model.eval()
@@ -418,11 +431,13 @@ def validate(self, model, dataloader, max_batches):
             if self.progress_bar and self.prog_bar is not None:
                 self.prog_bar.update(1)
 
-        # give model a chance to do something with the outputs
-        if self.data_parallel:
-            val_results = model.module.validation_end(outputs)
-        else:
-            val_results = model.validation_end(outputs)
+        # give model a chance to do something with the outputs (and method defined)
+        val_results = {}
+        if self.__is_overriden('validation_end'):
+            if self.data_parallel:
+                val_results = model.module.validation_end(outputs)
+            else:
+                val_results = model.validation_end(outputs)
 
         # enable train mode again
         model.train()
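This guard covers the "val_step can be implemented but not validation_end" case from the commit message: the per-batch outputs are still gathered, but val_results simply stays {} when validation_end is not overridden. A hedged sketch of such a model, reusing the hypothetical MinimalModel from the README sketch above (again illustrative, not part of the commit):

```python
import os

import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import pytorch_lightning as pl


class ValStepOnlyModel(MinimalModel):  # MinimalModel: the illustrative class sketched earlier
    def validation_step(self, batch, batch_nb):
        # OPTIONAL: runs for each validation batch
        x, y = batch
        return {'val_loss': F.cross_entropy(self.forward(x), y)}

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL: provides the batches the validation loop iterates over
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()), batch_size=32)

    # validation_end is deliberately absent: with this commit the trainer still
    # runs validation_step for each batch but leaves val_results as {}
```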
@@ -439,6 +454,7 @@ def get_dataloaders(self, model):
         :return:
         """
         self.tng_dataloader = model.tng_dataloader
+
         self.test_dataloader = model.test_dataloader
         self.val_dataloader = model.val_dataloader
pytorch_lightning/root_module/root_module.py

Lines changed: 9 additions & 7 deletions
@@ -36,18 +36,20 @@ def forward(self, *args, **kwargs):
     def validation_step(self, data_batch, batch_nb):
         """
         return whatever outputs will need to be aggregated in validation_end
+        OPTIONAL
         :param data_batch:
         :return:
         """
-        raise NotImplementedError
+        pass
 
     def validation_end(self, outputs):
         """
         Outputs has the appended output after each validation step
+        OPTIONAL
         :param outputs:
         :return: dic_with_metrics for tqdm
         """
-        raise NotImplementedError
+        pass
 
     def training_step(self, data_batch, batch_nb):
         """
@@ -67,26 +69,26 @@ def configure_optimizers(self):
     @data_loader
     def tng_dataloader(self):
         """
-        Implement a function to load an h5py of this data
+        Implement a PyTorch DataLoader
         :return:
         """
         raise NotImplementedError
 
     @data_loader
     def test_dataloader(self):
         """
-        Implement a function to load an h5py of this data
+        Implement a PyTorch DataLoader
         :return:
         """
-        raise NotImplementedError
+        return None
 
     @data_loader
     def val_dataloader(self):
         """
-        Implement a function to load an h5py of this data
+        Implement a PyTorch DataLoader
         :return:
         """
-        raise NotImplementedError
+        return None
 
     @classmethod
     def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None):
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from .lm_test_module import LightningTestModel
+from .no_val_end_module import NoValEndTestModel
+from .no_val_module import NoValModel
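The relative imports suggest this new file is the __init__.py of a test-models package whose path is not shown in this excerpt. Judging by the commit message, NoValModel presumably defines no validation hooks at all, while NoValEndTestModel implements validation_step without validation_end; their bodies are not part of this diff. A purely hypothetical sketch of how the added tests might exercise these fixtures (the fixture bodies and Trainer arguments are assumptions; only the pytorch_lightning.models.trainer path comes from the diff above):

```python
# Hypothetical test sketch -- not code from this commit.
from pytorch_lightning.models.trainer import Trainer  # module path taken from the trainer.py diff above

# NoValModel / NoValEndTestModel would be imported from the test-models package
# defined by the __init__.py above (its path is not visible in this excerpt).


def test_fit_without_any_validation_hooks():
    model = NoValModel()   # no validation_step / validation_end / val_dataloader
    Trainer().fit(model)   # with this commit, fit() skips validation instead of raising


def test_fit_without_validation_end():
    model = NoValEndTestModel()  # validation_step only
    Trainer().fit(model)         # validation runs; val_results stays {}
```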
