Skip to content

Commit 1abf889

Browse files
kaushikb11rohitgr7
authored andcommitted
Remove deprecation warnings being called for on_{task}_dataloader (#9279)
* Avoid deprecation warnings being called when hooks are not implemented * Update tests & changelog * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> Conflicts: CHANGELOG.md pytorch_lightning/core/hooks.py tests/deprecated_api/test_remove_1-7.py tests/trainer/test_trainer.py
1 parent 404dafc commit 1abf889

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1414

1515
- Fixed an issues with export to ONNX format when a model has multiple inputs ([#8800](https://github.com/PyTorchLightning/pytorch-lightning/pull/8800))
1616

17+
- Removed deprecation warnings being called for `on_{task}_dataloader` ([#9279](https://github.com/PyTorchLightning/pytorch-lightning/pull/9279))
18+
1719
- Fixed save/load/resume from checkpoint for DeepSpeed Plugin (
1820
[#8397](https://github.com/PyTorchLightning/pytorch-lightning/pull/8397),
1921
[#8644](https://github.com/PyTorchLightning/pytorch-lightning/pull/8644),

pytorch_lightning/core/hooks.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,15 +669,35 @@ def predict_dataloader(self) -> EVAL_DATALOADERS:
669669
def on_train_dataloader(self) -> None:
670670
"""Called before requesting the train dataloader."""
671671

672+
.. deprecated:: v1.5
673+
:meth:`on_train_dataloader` is deprecated and will be removed in v1.7.0.
674+
Please use :meth:`train_dataloader()` directly.
675+
"""
676+
672677
def on_val_dataloader(self) -> None:
673678
"""Called before requesting the val dataloader."""
674679
680+
.. deprecated:: v1.5
681+
:meth:`on_val_dataloader` is deprecated and will be removed in v1.7.0.
682+
Please use :meth:`val_dataloader()` directly.
683+
"""
684+
675685
def on_test_dataloader(self) -> None:
676686
"""Called before requesting the test dataloader."""
677687

688+
.. deprecated:: v1.5
689+
:meth:`on_test_dataloader` is deprecated and will be removed in v1.7.0.
690+
Please use :meth:`test_dataloader()` directly.
691+
"""
692+
678693
def on_predict_dataloader(self) -> None:
679694
"""Called before requesting the predict dataloader."""
680695
696+
.. deprecated:: v1.5
697+
:meth:`on_predict_dataloader` is deprecated and will be removed in v1.7.0.
698+
Please use :meth:`predict_dataloader()` directly.
699+
"""
700+
681701
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
682702
"""
683703
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors

tests/trainer/test_trainer.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,3 +1896,53 @@ def current_memory():
18961896
trainer_2.fit(model)
18971897

18981898
assert current_memory() <= initial
1899+
1900+
1901+
class TrainerStagesErrorsModel(BoringModel):
1902+
def on_train_start(self) -> None:
1903+
raise Exception("Error during train")
1904+
1905+
def on_validation_start(self) -> None:
1906+
raise Exception("Error during validation")
1907+
1908+
def on_test_start(self) -> None:
1909+
raise Exception("Error during test")
1910+
1911+
def on_predict_start(self) -> None:
1912+
raise Exception("Error during predict")
1913+
1914+
1915+
@pytest.mark.parametrize(
1916+
"accelerator,num_processes",
1917+
[
1918+
(None, 1),
1919+
pytest.param("ddp_cpu", 2, marks=RunIf(skip_windows=True)),
1920+
],
1921+
)
1922+
def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
1923+
model = TrainerStagesErrorsModel()
1924+
trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True)
1925+
1926+
with pytest.raises(Exception, match=r"Error during train"), patch(
1927+
"pytorch_lightning.Trainer._on_exception"
1928+
) as exception_hook:
1929+
trainer.fit(model)
1930+
exception_hook.assert_called()
1931+
1932+
with pytest.raises(Exception, match=r"Error during validation"), patch(
1933+
"pytorch_lightning.Trainer._on_exception"
1934+
) as exception_hook:
1935+
trainer.validate(model)
1936+
exception_hook.assert_called()
1937+
1938+
with pytest.raises(Exception, match=r"Error during test"), patch(
1939+
"pytorch_lightning.Trainer._on_exception"
1940+
) as exception_hook:
1941+
trainer.test(model)
1942+
exception_hook.assert_called()
1943+
1944+
with pytest.raises(Exception, match=r"Error during predict"), patch(
1945+
"pytorch_lightning.Trainer._on_exception"
1946+
) as exception_hook:
1947+
trainer.predict(model, model.val_dataloader(), return_predictions=False)
1948+
exception_hook.assert_called()

0 commit comments

Comments
 (0)