Skip to content

Commit 582b8cc

Browse files
awaelchlilexierule
authored andcommitted
Better error message when dataloader and datamodule is None (V2) (#14637)
1 parent 72f82eb commit 582b8cc

File tree

5 files changed

+67
-45
lines changed

5 files changed

+67
-45
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77
## [1.7.6] - 2022-09-13
88

9+
### Changed
10+
11+
- When using multiple loggers, by default checkpoints and profiler output now get saved to the log dir of the first logger in the list ([#14325](https://github.com/Lightning-AI/lightning/pull/14325))
12+
- Improved the error messaging when passing `Trainer.method(model, x_dataloader=None)` with no module-method implementations available ([#14614](https://github.com/Lightning-AI/lightning/pull/14614))
13+
914
### Fixed
1015

1116
- Reset the dataloaders on OOM failure in batch size finder to use the last successful batch size ([#14372](https://github.com/Lightning-AI/lightning/pull/14372))

src/pytorch_lightning/trainer/configuration_validator.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,6 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh
7171
" `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
7272
)
7373

74-
# -----------------------------------
75-
# verify model has a train dataloader
76-
# -----------------------------------
77-
has_train_dataloader = trainer._data_connector._train_dataloader_source.is_defined()
78-
if not has_train_dataloader:
79-
raise MisconfigurationException(
80-
"No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
81-
" `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined."
82-
)
83-
8474
# -----------------------------------
8575
# verify model has optimizer
8676
# -----------------------------------
@@ -121,19 +111,11 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh
121111

122112

123113
def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule", stage: str) -> None:
124-
loader_name = f"{stage}_dataloader"
125114
step_name = "validation_step" if stage == "val" else f"{stage}_step"
126115
trainer_method = "validate" if stage == "val" else stage
127116

128-
has_loader = getattr(trainer._data_connector, f"_{stage}_dataloader_source").is_defined()
129117
has_step = is_overridden(step_name, model)
130118

131-
# -----------------------------------
132-
# verify model has an eval_dataloader
133-
# -----------------------------------
134-
if not has_loader:
135-
raise MisconfigurationException(f"No `{loader_name}()` method defined to run `Trainer.{trainer_method}`.")
136-
137119
# predict_step is not required to be overridden
138120
if stage == "predict":
139121
if model.predict_step is None:

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ def attach_data(
144144
predict_dataloaders=predict_dataloaders,
145145
)
146146
self.attach_datamodule(model, datamodule=datamodule)
147+
148+
# Validate that the required data sources are available
149+
if self.trainer.state.fn == TrainerFn.FITTING:
150+
_check_dataloader_none(train_dataloaders, self._train_dataloader_source, self.trainer.state.fn)
151+
elif self.trainer.state.fn == TrainerFn.VALIDATING:
152+
_check_dataloader_none(val_dataloaders, self._val_dataloader_source, self.trainer.state.fn)
153+
elif self.trainer.state.fn == TrainerFn.TESTING:
154+
_check_dataloader_none(test_dataloaders, self._test_dataloader_source, self.trainer.state.fn)
155+
elif self.trainer.state.fn == TrainerFn.PREDICTING:
156+
_check_dataloader_none(predict_dataloaders, self._predict_dataloader_source, self.trainer.state.fn)
157+
147158
# set local properties on the model
148159
self._copy_trainer_model_properties(model)
149160

@@ -581,3 +592,18 @@ def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.Lightn
581592
" `LightningDataModule`. It will use the implementation from `LightningModule` instance."
582593
)
583594
return self.model
595+
596+
597+
def _check_dataloader_none(
598+
dataloader: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]],
599+
dataloader_source: _DataLoaderSource,
600+
trainer_fn: TrainerFn,
601+
) -> None:
602+
# A prefix in the message to disambiguate between the train- and (optional) val dataloader that .fit() accepts
603+
prefix = "train_" if trainer_fn == TrainerFn.FITTING else ""
604+
if dataloader is None and not dataloader_source.is_defined():
605+
raise ValueError(
606+
f"An invalid dataloader was passed to `Trainer.{trainer_fn}({prefix}dataloaders=...)`."
607+
f" Either pass the dataloader to the `.{trainer_fn}()` method OR implement"
608+
f" `def {dataloader_source.name}(self):` in your LightningModule/LightningDataModule."
609+
)

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,3 +570,38 @@ def test_error_raised_with_insufficient_float_limit_train_dataloader():
570570
match="Please increase the `limit_train_batches` argument. Try at least",
571571
):
572572
trainer.reset_train_dataloader(model)
573+
574+
575+
@pytest.mark.parametrize(
576+
"trainer_fn_name, dataloader_name",
577+
[
578+
("fit", "train_dataloaders"),
579+
("validate", "dataloaders"),
580+
("test", "dataloaders"),
581+
("predict", "dataloaders"),
582+
],
583+
)
584+
def test_attach_data_input_validation_with_none_dataloader(trainer_fn_name, dataloader_name, tmpdir):
585+
"""Test that passing `Trainer.method(x_dataloader=None)` with no module-method implementations available raises
586+
an error."""
587+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
588+
model = BoringModel()
589+
datamodule = BoringDataModule()
590+
trainer_fn = getattr(trainer, trainer_fn_name)
591+
592+
# Pretend that these methods are not implemented
593+
model.train_dataloader = None
594+
model.val_dataloader = None
595+
model.test_dataloader = None
596+
model.predict_dataloader = None
597+
598+
datamodule.train_dataloader = None
599+
datamodule.val_dataloader = None
600+
datamodule.test_dataloader = None
601+
datamodule.predict_dataloader = None
602+
603+
with pytest.raises(ValueError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"):
604+
trainer_fn(model, **{dataloader_name: None}, datamodule=datamodule)
605+
606+
with pytest.raises(ValueError, match=f"An invalid .*dataloader was passed to `Trainer.{trainer_fn_name}"):
607+
trainer_fn(model, **{dataloader_name: None}, datamodule=None)

tests/tests_pytorch/trainer/test_config_validator.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,9 @@
2222

2323

2424
def test_wrong_train_setting(tmpdir):
25-
"""
26-
* Test that an error is thrown when no `train_dataloader()` is defined
27-
* Test that an error is thrown when no `training_step()` is defined
28-
"""
25+
"""Test that an error is raised when no `training_step()` is defined."""
2926
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
3027

31-
with pytest.raises(MisconfigurationException, match=r"No `train_dataloader\(\)` method defined."):
32-
model = BoringModel()
33-
model.train_dataloader = None
34-
trainer.fit(model)
35-
3628
with pytest.raises(MisconfigurationException, match=r"No `training_step\(\)` method defined."):
3729
model = BoringModel()
3830
model.training_step = None
@@ -70,36 +62,18 @@ def test_eval_loop_config(tmpdir):
7062
"""When either eval step or eval data is missing."""
7163
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
7264

73-
# has val step but no val data
74-
model = BoringModel()
75-
model.val_dataloader = None
76-
with pytest.raises(MisconfigurationException, match=r"No `val_dataloader\(\)` method defined"):
77-
trainer.validate(model)
78-
7965
# has test data but no val step
8066
model = BoringModel()
8167
model.validation_step = None
8268
with pytest.raises(MisconfigurationException, match=r"No `validation_step\(\)` method defined"):
8369
trainer.validate(model)
8470

85-
# has test loop but no test data
86-
model = BoringModel()
87-
model.test_dataloader = None
88-
with pytest.raises(MisconfigurationException, match=r"No `test_dataloader\(\)` method defined"):
89-
trainer.test(model)
90-
9171
# has test data but no test step
9272
model = BoringModel()
9373
model.test_step = None
9474
with pytest.raises(MisconfigurationException, match=r"No `test_step\(\)` method defined"):
9575
trainer.test(model)
9676

97-
# has predict step but no predict data
98-
model = BoringModel()
99-
model.predict_dataloader = None
100-
with pytest.raises(MisconfigurationException, match=r"No `predict_dataloader\(\)` method defined"):
101-
trainer.predict(model)
102-
10377
# has predict data but no predict_step
10478
model = BoringModel()
10579
model.predict_step = None

0 commit comments

Comments
 (0)