diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 176e34273d776..96fcf3daa6396 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added Torch-Tensorrt integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))
 
+- Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))
+
 ### Changed
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 31d6724a043a3..f25c33359a78a 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -414,6 +414,9 @@ def on_run_start(self) -> None:
             self.epoch_loop.val_loop.setup_data()
 
         trainer.training = True
+        # Check for modules in eval mode at training start
+        self._warn_if_modules_in_eval_mode()
+
         call._call_callback_hooks(trainer, "on_train_start")
         call._call_lightning_module_hook(trainer, "on_train_start")
         call._call_strategy_hook(trainer, "on_train_start")
@@ -515,6 +518,19 @@ def on_load_checkpoint(self, state_dict: dict) -> None:
         self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
         super().on_load_checkpoint(state_dict)
 
+    def _warn_if_modules_in_eval_mode(self) -> None:
+        """Warn if any modules are in eval mode at the start of training."""
+        model = self.trainer.lightning_module
+        eval_modules = [name for name, module in model.named_modules() if not module.training]
+
+        if eval_modules:
+            rank_zero_warn(
+                f"Found {len(eval_modules)} module(s) in eval mode at the start of training."
+                " This may lead to unexpected behavior during training. If this is intentional,"
+                " you can ignore this warning.",
+                category=PossibleUserWarning,
+            )
+
     def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
         return self.epoch_loop._should_accumulate()
diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py
index e3a4c37f6a284..f5aaa18095fc5 100644
--- a/tests/tests_pytorch/loops/test_training_loop.py
+++ b/tests/tests_pytorch/loops/test_training_loop.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 import itertools
 import logging
+import warnings
 from unittest.mock import Mock
 
 import pytest
 import torch
 from torch.utils.data import DataLoader
 
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.loops import _FitLoop
@@ -277,3 +279,29 @@ def __iter__(self):
 
     # assert progress bar callback uses correct total steps
     assert pbar.train_progress_bar.total == max_steps
+
+
+@pytest.mark.parametrize("warn", [True, False])
+def test_eval_mode_warning(tmp_path, warn):
+    """Test that a warning is raised if any module is in eval mode at the start of training."""
+    model = BoringModel()
+    if warn:
+        model.some_eval_module = torch.nn.Linear(32, 16)
+        model.some_eval_module.eval()
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+    )
+
+    if warn:
+        with pytest.warns(PossibleUserWarning):
+            trainer.fit(model)
+    else:
+        with warnings.catch_warnings(record=True) as warning_list:
+            warnings.simplefilter("always")
+            trainer.fit(model)
+            eval_warnings = [
+                w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
+            ]
+            assert len(eval_warnings) == 0, "Expected no eval mode warnings"
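
For anyone who wants to see the new behavior end to end, below is a minimal repro sketch based on the classes this patch already uses; the `FrozenBackboneModel` class, the layer size, and `fast_dev_run=True` are illustrative assumptions, not part of this diff.

```python
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class FrozenBackboneModel(BoringModel):
    """BoringModel with a submodule deliberately left in eval mode before training."""

    def __init__(self):
        super().__init__()
        # A hypothetical frozen submodule; with this patch it should trigger the
        # new PossibleUserWarning when fit() starts.
        self.backbone = torch.nn.Linear(32, 32)
        self.backbone.eval()


if __name__ == "__main__":
    # fit() now reports how many modules were found in eval mode at the start of training.
    Trainer(fast_dev_run=True).fit(FrozenBackboneModel())
```

Because the check only emits a regular Python warning, users who freeze modules on purpose can silence it with a standard filter such as `warnings.filterwarnings("ignore", category=PossibleUserWarning)`.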