Skip to content

Commit c993d0c

Browse files
authored
Make unimplemented dataloader hooks raise NotImplementedError (#9161)
1 parent 3fd77cb commit c993d0c

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
143143
- `Trainer.request_dataloader` now takes a `RunningStage` enum instance ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
144144

145145

146+
- Changed `rank_zero_warn` to `NotImplementedError` in the `{train, val, test, predict}_dataloader` hooks that `Lightning(Data)Module` uses ([#9161](https://github.com/PyTorchLightning/pytorch-lightning/pull/9161))
147+
146148
### Deprecated
147149

148150
- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`

pytorch_lightning/core/hooks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch
1919
from torch.optim.optimizer import Optimizer
2020

21-
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
21+
from pytorch_lightning.utilities import move_data_to_device
2222
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, STEP_OUTPUT, TRAIN_DATALOADERS
2323

2424

@@ -540,7 +540,7 @@ def train_dataloader(self):
540540
return {'mnist': mnist_loader, 'cifar': cifar_loader}
541541
542542
"""
543-
rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")
543+
raise NotImplementedError("`train_dataloader` must be implemented to be used with the Lightning Trainer")
544544

545545
def test_dataloader(self) -> EVAL_DATALOADERS:
546546
r"""
@@ -602,6 +602,7 @@ def test_dataloader(self):
602602
In the case where you return multiple test dataloaders, the :meth:`test_step`
603603
will have an argument ``dataloader_idx`` which matches the order here.
604604
"""
605+
raise NotImplementedError("`test_dataloader` must be implemented to be used with the Lightning Trainer")
605606

606607
def val_dataloader(self) -> EVAL_DATALOADERS:
607608
r"""
@@ -654,6 +655,7 @@ def val_dataloader(self):
654655
In the case where you return multiple validation dataloaders, the :meth:`validation_step`
655656
will have an argument ``dataloader_idx`` which matches the order here.
656657
"""
658+
raise NotImplementedError("`val_dataloader` must be implemented to be used with the Lightning Trainer")
657659

658660
def predict_dataloader(self) -> EVAL_DATALOADERS:
659661
r"""
@@ -679,6 +681,7 @@ def predict_dataloader(self) -> EVAL_DATALOADERS:
679681
In the case where you return multiple prediction dataloaders, the :meth:`predict`
680682
will have an argument ``dataloader_idx`` which matches the order here.
681683
"""
684+
raise NotImplementedError("`predict_dataloader` must be implemented to be used with the Lightning Trainer")
682685

683686
def on_train_dataloader(self) -> None:
684687
"""Called before requesting the train dataloader."""

tests/core/test_datamodules.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,10 @@ def test_dm_init_from_datasets_dataloaders(iterable):
480480
with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
481481
dm.train_dataloader()
482482
dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
483-
assert dm.val_dataloader() is None
484-
assert dm.test_dataloader() is None
483+
with pytest.raises(NotImplementedError):
484+
_ = dm.val_dataloader()
485+
with pytest.raises(NotImplementedError):
486+
_ = dm.test_dataloader()
485487

486488
train_ds_sequence = [ds(), ds()]
487489
dm = LightningDataModule.from_datasets(train_ds_sequence, batch_size=4, num_workers=0)
@@ -493,8 +495,10 @@ def test_dm_init_from_datasets_dataloaders(iterable):
493495
call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
494496
]
495497
)
496-
assert dm.val_dataloader() is None
497-
assert dm.test_dataloader() is None
498+
with pytest.raises(NotImplementedError):
499+
_ = dm.val_dataloader()
500+
with pytest.raises(NotImplementedError):
501+
_ = dm.test_dataloader()
498502

499503
valid_ds = ds()
500504
test_ds = ds()
@@ -504,7 +508,8 @@ def test_dm_init_from_datasets_dataloaders(iterable):
504508
dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
505509
dm.test_dataloader()
506510
dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
507-
assert dm.train_dataloader() is None
511+
with pytest.raises(NotImplementedError):
512+
_ = dm.train_dataloader()
508513

509514
valid_dss = [ds(), ds()]
510515
test_dss = [ds(), ds()]

0 commit comments

Comments
 (0)