From 759b93201c16a446f4a2afbf91b4035a644e8798 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 1 Sep 2025 14:43:27 +0200 Subject: [PATCH 1/4] add warning --- src/lightning/pytorch/loops/fit_loop.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 31d6724a043a3..bf34b15cdc6eb 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -414,6 +414,9 @@ def on_run_start(self) -> None: self.epoch_loop.val_loop.setup_data() trainer.training = True + # Check for modules in eval mode at training start + self._warn_if_modules_in_eval_mode() + call._call_callback_hooks(trainer, "on_train_start") call._call_lightning_module_hook(trainer, "on_train_start") call._call_strategy_hook(trainer, "on_train_start") @@ -515,6 +518,19 @@ def on_load_checkpoint(self, state_dict: dict) -> None: self._combined_loader_states_to_load = state_dict.get("combined_loader", []) super().on_load_checkpoint(state_dict) + def _warn_if_modules_in_eval_mode(self) -> None: + """Warn if any modules are in eval mode at the start of training.""" + model = self.trainer.lightning_module + eval_modules = [name for name, module in model.named_modules() if not module.training] + + if eval_modules: + rank_zero_warn( + f"Found {len(eval_modules)} module(s) in eval mode at the start of training." + f" This may lead to unexpected behavior during training. 
@pytest.mark.parametrize("warn", [True, False])
def test_eval_mode_warning(tmp_path, warn):
    """Test that a warning is raised if any module is in eval mode at the start of training."""
    model = BoringModel()
    if warn:
        # Attach an extra submodule and freeze it in eval mode before fit()
        # starts, so the fit loop's eval-mode check has something to find.
        model.some_eval_module = torch.nn.Linear(32, 16)
        model.some_eval_module.eval()

    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=1,
    )

    if warn:
        # match= pins the assertion to the eval-mode warning specifically,
        # so an unrelated PossibleUserWarning cannot make this branch pass.
        with pytest.warns(PossibleUserWarning, match="eval mode"):
            trainer.fit(model)
    else:
        with warnings.catch_warnings(record=True) as warning_list:
            warnings.simplefilter("always")
            trainer.fit(model)
        # Filter to the specific warning under test; other PossibleUserWarnings
        # emitted during a normal fit() must not fail this branch.
        eval_warnings = [
            w
            for w in warning_list
            if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
        ]
        assert len(eval_warnings) == 0, "Expected no eval mode warnings"
-                f" This may lead to unexpected behavior during training. If this is intentional,"
+                " This may lead to unexpected behavior during training. If this is intentional,"
                 " you can ignore this warning.",
                 category=PossibleUserWarning,
             )