2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Torch-TensorRT integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))


- Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))

### Changed

- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fall back to `TQDMProgressBar` and `ModelSummary` otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
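
A rough sketch of the user-facing behavior added by this PR (the `frozen_head` submodule and the Trainer arguments below are illustrative, not taken from the diff): leaving any submodule in eval mode before calling `fit()` should now emit the new `PossibleUserWarning`.

```python
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.frozen_head = torch.nn.Linear(32, 2)  # illustrative extra submodule
model.frozen_head.eval()                    # deliberately left in eval mode

trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(model)
# Expected: PossibleUserWarning reporting module(s) in eval mode at the start of training
```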
16 changes: 16 additions & 0 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -414,6 +414,9 @@ def on_run_start(self) -> None:
self.epoch_loop.val_loop.setup_data()
trainer.training = True

# Check for modules in eval mode at training start
self._warn_if_modules_in_eval_mode()

call._call_callback_hooks(trainer, "on_train_start")
call._call_lightning_module_hook(trainer, "on_train_start")
call._call_strategy_hook(trainer, "on_train_start")
@@ -515,6 +518,19 @@ def on_load_checkpoint(self, state_dict: dict) -> None:
self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
super().on_load_checkpoint(state_dict)

def _warn_if_modules_in_eval_mode(self) -> None:
"""Warn if any modules are in eval mode at the start of training."""
model = self.trainer.lightning_module
eval_modules = [name for name, module in model.named_modules() if not module.training]

if eval_modules:
rank_zero_warn(
f"Found {len(eval_modules)} module(s) in eval mode at the start of training."
" This may lead to unexpected behavior during training. If this is intentional,"
" you can ignore this warning.",
category=PossibleUserWarning,
)

def _should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated."""
return self.epoch_loop._should_accumulate()
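
For clarity, a minimal standalone sketch of the detection idea in `_warn_if_modules_in_eval_mode` above: walk `named_modules()` and collect anything whose `training` flag is `False`. The toy `Sequential` model is purely illustrative.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.BatchNorm1d(8),
)
model[1].eval()  # e.g. a norm layer that was frozen earlier and never switched back

# Same comprehension as in the hook: names of all modules not in train mode
eval_modules = [name for name, module in model.named_modules() if not module.training]
print(eval_modules)  # ['1'] -> only the BatchNorm1d is still in eval mode
```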
28 changes: 28 additions & 0 deletions tests/tests_pytorch/loops/test_training_loop.py
@@ -13,12 +13,14 @@
# limitations under the License.
import itertools
import logging
import warnings
from unittest.mock import Mock

import pytest
import torch
from torch.utils.data import DataLoader

from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops import _FitLoop
@@ -277,3 +279,29 @@ def __iter__(self):

# assert progress bar callback uses correct total steps
assert pbar.train_progress_bar.total == max_steps


@pytest.mark.parametrize("warn", [True, False])
def test_eval_mode_warning(tmp_path, warn):
"""Test that a warning is raised if any module is in eval mode at the start of training."""
model = BoringModel()
if warn:
model.some_eval_module = torch.nn.Linear(32, 16)
model.some_eval_module.eval()

trainer = Trainer(
default_root_dir=tmp_path,
max_epochs=1,
)

if warn:
with pytest.warns(PossibleUserWarning):
trainer.fit(model)
else:
with warnings.catch_warnings(record=True) as warning_list:
warnings.simplefilter("always")
trainer.fit(model)
eval_warnings = [
w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
]
assert len(eval_warnings) == 0, "Expected no eval mode warnings"
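
Assuming a standard development checkout, the new test can presumably be run on its own with `pytest tests/tests_pytorch/loops/test_training_loop.py -k test_eval_mode_warning`.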