
Commit c0d66b7

Merge branch 'master' into uv-for-pytorch-tests

2 parents 1489ff7 + 4824cc1

File tree

10 files changed: +192 -27 lines

.github/workflows/ci-tests-fabric.yml

Lines changed: 2 additions & 2 deletions

@@ -79,7 +79,7 @@ jobs:
         run: pip install -q -r .actions/requirements.txt

       - name: Set min. dependencies
-        if: ${{ matrix.requires == 'oldest' }}
+        if: ${{ matrix.config.requires == 'oldest' }}
         run: |
           cd requirements/fabric
           pip install -U "lightning-utilities[cli]"

@@ -88,7 +88,7 @@ jobs:
           pip install "pyyaml==5.4" --no-build-isolation

       - name: Adjust PyTorch versions in requirements files
-        if: ${{ matrix.requires != 'oldest' }}
+        if: ${{ matrix.config.requires != 'oldest' }}
         run: |
           pip install -q -r requirements/ci.txt
           python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py

.github/workflows/ci-tests-pytorch.yml

Lines changed: 2 additions & 2 deletions

@@ -93,7 +93,7 @@ jobs:
         run: uv pip install -q -r .actions/requirements.txt

       - name: Set min. dependencies
-        if: ${{ matrix.requires == 'oldest' }}
+        if: ${{ matrix.config.requires == 'oldest' }}
         run: |
           cd requirements/pytorch
           uv pip install -U "lightning-utilities[cli]"

@@ -102,7 +102,7 @@ jobs:
           uv pip install "pyyaml==5.4" --no-build-isolation

       - name: Adjust PyTorch versions in requirements files
-        if: ${{ matrix.requires != 'oldest' }}
+        if: ${{ matrix.config.requires != 'oldest' }}
         run: |
           uv pip install -q -r requirements/ci.txt
           python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py
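
Both workflow edits make the same fix: the `if:` guards still referenced `matrix.requires`, but the job matrix evidently nests each entry's options under a `config` key, so the conditions become `matrix.config.requires`. The only substantive difference between the two files is that the pytorch workflow has already moved to `uv pip` (hence the branch name), while the fabric workflow still calls pip directly.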

docs/source-pytorch/common/hooks.rst

Lines changed: 18 additions & 16 deletions

@@ -143,23 +143,24 @@ with the source of each hook indicated:
     │   │   │   ├── [LightningModule]
     │   │   │   └── [Strategy]
     │   │   │
-    │   │   ├── on_before_zero_grad()
-    │   │   │   ├── [Callbacks]
-    │   │   │   └── [LightningModule]
-    │   │   │
     │   │   ├── [Forward Pass - training_step()]
     │   │   │   └── [Strategy only]
     │   │   │
-    │   │   ├── on_before_backward()
+    │   │   ├── on_before_zero_grad()
     │   │   │   ├── [Callbacks]
     │   │   │   └── [LightningModule]
     │   │   │
-    │   │   ├── [Backward Pass]
-    │   │   │   └── [Strategy only]
+    │   │   ├── optimizer_zero_grad()
+    │   │   │   └── [LightningModule only - optimizer_zero_grad()]
     │   │   │
-    │   │   ├── on_after_backward()
-    │   │   │   ├── [Callbacks]
-    │   │   │   └── [LightningModule]
+    │   │   ├── [Backward Pass - Strategy.backward()]
+    │   │   │   ├── on_before_backward()
+    │   │   │   │   ├── [Callbacks]
+    │   │   │   │   └── [LightningModule]
+    │   │   │   ├── LightningModule.backward()
+    │   │   │   └── on_after_backward()
+    │   │   │       ├── [Callbacks]
+    │   │   │       └── [LightningModule]
     │   │   │
     │   │   ├── on_before_optimizer_step()
     │   │   │   ├── [Callbacks]

@@ -212,13 +213,14 @@
     │   ├── [LightningModule]
     │   └── [Strategy]

-    ├── on_fit_end()
-    │   ├── [Callbacks]
-    │   ├── [LightningModule]
-    │   └── [Strategy]
-
     └── teardown(stage="fit")
-        └── [Callbacks only]
+        ├── [Strategy]
+        ├── on_fit_end()
+        │   ├── [Callbacks]
+        │   └── [LightningModule]
+        ├── [LightningDataModule]
+        ├── [Callbacks]
+        └── [LightningModule]

 ***********************
 Testing Loop Hook Order
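
The revised tree corrects the documented hook order: `on_before_zero_grad()` and `optimizer_zero_grad()` now appear after the forward pass, and `on_before_backward()`, `LightningModule.backward()`, and `on_after_backward()` are nested inside the backward pass driven by `Strategy.backward()`. A minimal sketch to watch that order at runtime (the model and print-based tracing are illustrative, not part of the diff):

```python
import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel


class HookOrderModel(BoringModel):
    """Print one line per hook to observe the documented call order."""

    def on_before_zero_grad(self, optimizer):
        print("on_before_zero_grad")

    def on_before_backward(self, loss):
        print("on_before_backward")

    def backward(self, loss, *args, **kwargs):
        print("LightningModule.backward")
        super().backward(loss, *args, **kwargs)

    def on_after_backward(self):
        print("on_after_backward")


trainer = pl.Trainer(max_steps=2, logger=False, enable_checkpointing=False, enable_progress_bar=False)
trainer.fit(HookOrderModel())
```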

src/lightning/pytorch/CHANGELOG.md

Lines changed: 7 additions & 2 deletions

@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added Torch-Tensorrt integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))


+- Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))
+
 ### Changed

 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))

@@ -28,14 +30,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed `LightningCLI` not using `ckpt_path` hyperparameters to instantiate classes ([#21116](https://github.com/Lightning-AI/pytorch-lightning/pull/21116))
+
+
 - Fixed callbacks by defer step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))


 - Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))


----
+- Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))

+---

 ## [2.5.4] - 2025-08-29

@@ -47,7 +53,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed misalignment column while using rich model summary in `DeepSpeedstrategy` ([#21100](https://github.com/Lightning-AI/pytorch-lightning/pull/21100))
 - Fixed `RichProgressBar` crashing when sanity checking using val dataloader with 0 len ([#21108](https://github.com/Lightning-AI/pytorch-lightning/pull/21108))

-
 ## [2.5.3] - 2025-08-13

 ### Changed

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

Lines changed: 12 additions & 4 deletions

@@ -265,7 +265,9 @@ def on_train_start(self, *_: Any) -> None:
     def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
         if self._leave:
             self.train_progress_bar = self.init_train_tqdm()
-        self.train_progress_bar.reset(convert_inf(self.total_train_batches))
+        total = convert_inf(self.total_train_batches)
+        self.train_progress_bar.reset()
+        self.train_progress_bar.total = total
         self.train_progress_bar.initial = 0
         self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

@@ -306,7 +308,9 @@ def on_validation_batch_start(
         if not self.has_dataloader_changed(dataloader_idx):
             return

-        self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
+        total = convert_inf(self.total_val_batches_current_dataloader)
+        self.val_progress_bar.reset()
+        self.val_progress_bar.total = total
         self.val_progress_bar.initial = 0
         desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
         self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")

@@ -348,7 +352,9 @@ def on_test_batch_start(
         if not self.has_dataloader_changed(dataloader_idx):
             return

-        self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
+        total = convert_inf(self.total_test_batches_current_dataloader)
+        self.test_progress_bar.reset()
+        self.test_progress_bar.total = total
         self.test_progress_bar.initial = 0
         self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")

@@ -387,7 +393,9 @@ def on_predict_batch_start(
         if not self.has_dataloader_changed(dataloader_idx):
             return

-        self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
+        total = convert_inf(self.total_predict_batches_current_dataloader)
+        self.predict_progress_bar.reset()
+        self.predict_progress_bar.total = total
         self.predict_progress_bar.initial = 0
         self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
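
Why split `reset(total=...)` into two steps: `tqdm.reset()` restarts the counter, and assigning `total` afterwards makes the new length explicit even when the bar previously ran with an unknown total (an iterable dataloader). A standalone sketch of the pattern with plain tqdm, assuming only the documented `reset()` method and the writable `total` attribute:

```python
from tqdm import tqdm

# Epoch 1: bar driven by an iterable dataloader, so the length is unknown.
bar = tqdm(total=None)
bar.update(3)

# Epoch 2: the pattern from the diff -- reset the counter first, then set the
# finite total explicitly so the bar renders a proper percentage again.
bar.reset()
bar.total = 5
bar.initial = 0
bar.update(5)
bar.close()
```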

src/lightning/pytorch/cli.py

Lines changed: 20 additions & 0 deletions

@@ -16,6 +16,7 @@
 import sys
 from collections.abc import Iterable
 from functools import partial, update_wrapper
+from pathlib import Path
 from types import MethodType
 from typing import Any, Callable, Optional, TypeVar, Union

@@ -397,6 +398,7 @@ def __init__(
         main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
         self.setup_parser(run, main_kwargs, subparser_kwargs)
         self.parse_arguments(self.parser, args)
+        self._parse_ckpt_path()

         self.subcommand = self.config["subcommand"] if run else None

@@ -551,6 +553,24 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
         else:
             self.config = parser.parse_args(args)

+    def _parse_ckpt_path(self) -> None:
+        """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
+        if not self.config.get("subcommand"):
+            return
+        ckpt_path = self.config[self.config.subcommand].get("ckpt_path")
+        if ckpt_path and Path(ckpt_path).is_file():
+            ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu")
+            hparams = ckpt.get("hyper_parameters", {})
+            hparams.pop("_instantiator", None)
+            if not hparams:
+                return
+            hparams = {self.config.subcommand: {"model": hparams}}
+            try:
+                self.config = self.parser.parse_object(hparams, self.config)
+            except SystemExit:
+                sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n")
+                raise
+
     def _dump_config(self) -> None:
         if hasattr(self, "config_dump"):
             return
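
The new `_parse_ckpt_path()` step reads the `hyper_parameters` dict out of the checkpoint and feeds it back through the parser under `<subcommand>.model`, so the model is re-instantiated with the hyperparameters it was trained with. A standalone sketch of the checkpoint-reading half, mirroring the diff (the helper name is ours, not part of the API):

```python
from pathlib import Path

import torch


def hparams_from_checkpoint(ckpt_path: str) -> dict:
    """Return the user hyperparameters stored in a Lightning checkpoint, if any."""
    if not Path(ckpt_path).is_file():
        return {}
    ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu")
    hparams = ckpt.get("hyper_parameters", {})
    hparams.pop("_instantiator", None)  # internal bookkeeping key, not a user hparam
    return hparams
```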

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 16 additions & 0 deletions

@@ -414,6 +414,9 @@ def on_run_start(self) -> None:
         self.epoch_loop.val_loop.setup_data()
         trainer.training = True

+        # Check for modules in eval mode at training start
+        self._warn_if_modules_in_eval_mode()
+
         call._call_callback_hooks(trainer, "on_train_start")
         call._call_lightning_module_hook(trainer, "on_train_start")
         call._call_strategy_hook(trainer, "on_train_start")

@@ -515,6 +518,19 @@ def on_load_checkpoint(self, state_dict: dict) -> None:
         self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
         super().on_load_checkpoint(state_dict)

+    def _warn_if_modules_in_eval_mode(self) -> None:
+        """Warn if any modules are in eval mode at the start of training."""
+        model = self.trainer.lightning_module
+        eval_modules = [name for name, module in model.named_modules() if not module.training]
+
+        if eval_modules:
+            rank_zero_warn(
+                f"Found {len(eval_modules)} module(s) in eval mode at the start of training."
+                " This may lead to unexpected behavior during training. If this is intentional,"
+                " you can ignore this warning.",
+                category=PossibleUserWarning,
+            )
+
     def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
         return self.epoch_loop._should_accumulate()
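
A minimal sketch of the situation the new warning targets, modeled on the test added below: a submodule left in eval mode when `fit()` starts.

```python
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.frozen_head = torch.nn.Linear(32, 16)
model.frozen_head.eval()  # named_modules() now reports a module with training=False

# fit() should emit a PossibleUserWarning along the lines of:
# "Found 1 module(s) in eval mode at the start of training. ..."
trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
trainer.fit(model)
```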

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 47 additions & 0 deletions

@@ -812,3 +812,50 @@ def test_tqdm_leave(leave, tmp_path):
     )
     trainer.fit(model)
     assert pbar.init_train_tqdm.call_count == (4 if leave else 1)
+
+
+@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
+def test_tqdm_progress_bar_reset_behavior(tmp_path):
+    """Test that progress bars call reset() without parameters and set total separately."""
+    model = BoringModel()
+
+    class ResetTrackingTqdm(MockTqdm):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.reset_calls_with_params = []
+
+        def reset(self, total=None):
+            self.reset_calls_with_params.append(total)
+            super().reset(total)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        logger=False,
+        enable_checkpointing=False,
+    )
+
+    pbar = trainer.progress_bar_callback
+
+    with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", ResetTrackingTqdm):
+        trainer.fit(model)
+
+    train_bar = pbar.train_progress_bar
+    assert None in train_bar.reset_calls_with_params, (
+        f"train reset() should be called without parameters, got calls: {train_bar.reset_calls_with_params}"
+    )
+    # Verify that total was set separately to the expected value
+    assert 2 in train_bar.total_values, (
+        f"train total should be set to 2 after reset(), got total_values: {train_bar.total_values}"
+    )
+    # Verify that validation progress bar reset() was called without parameters
+    val_bar = pbar.val_progress_bar
+    assert None in val_bar.reset_calls_with_params, (
+        f"validation reset() should be called without parameters, got calls: {val_bar.reset_calls_with_params}"
+    )
+    # Verify that total was set separately to the expected value
+    assert 2 in val_bar.total_values, (
+        f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
+    )

tests/tests_pytorch/loops/test_training_loop.py

Lines changed: 28 additions & 0 deletions

@@ -13,12 +13,14 @@
 # limitations under the License.
 import itertools
 import logging
+import warnings
 from unittest.mock import Mock

 import pytest
 import torch
 from torch.utils.data import DataLoader

+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.loops import _FitLoop

@@ -277,3 +279,29 @@ def __iter__(self):

     # assert progress bar callback uses correct total steps
     assert pbar.train_progress_bar.total == max_steps
+
+
+@pytest.mark.parametrize("warn", [True, False])
+def test_eval_mode_warning(tmp_path, warn):
+    """Test that a warning is raised if any module is in eval mode at the start of training."""
+    model = BoringModel()
+    if warn:
+        model.some_eval_module = torch.nn.Linear(32, 16)
+        model.some_eval_module.eval()
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+    )
+
+    if warn:
+        with pytest.warns(PossibleUserWarning):
+            trainer.fit(model)
+    else:
+        with warnings.catch_warnings(record=True) as warning_list:
+            warnings.simplefilter("always")
+            trainer.fit(model)
+        eval_warnings = [
+            w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
+        ]
+        assert len(eval_warnings) == 0, "Expected no eval mode warnings"
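
If modules are intentionally kept in eval mode (a frozen encoder, say), the warning can be silenced with a standard warnings filter; a sketch, not an officially documented recipe:

```python
import warnings

from lightning.fabric.utilities.warnings import PossibleUserWarning

# Matches the message emitted by _warn_if_modules_in_eval_mode() above.
warnings.filterwarnings(
    "ignore",
    message=".*eval mode at the start of training.*",
    category=PossibleUserWarning,
)
```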

tests/tests_pytorch/test_cli.py

Lines changed: 40 additions & 1 deletion

@@ -17,7 +17,7 @@
 import operator
 import os
 import sys
-from contextlib import ExitStack, contextmanager, redirect_stdout
+from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout
 from io import StringIO
 from pathlib import Path
 from typing import Callable, Optional, Union

@@ -487,6 +487,45 @@ def test_lightning_cli_print_config():
     assert outval["ckpt_path"] is None


+class BoringCkptPathModel(BoringModel):
+    def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
+        super().__init__()
+        self.save_hyperparameters()
+        self.layer = torch.nn.Linear(32, out_dim)
+
+
+def test_lightning_cli_ckpt_path_argument_hparams(cleandir):
+    class CkptPathCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
+
+    cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = CkptPathCLI(BoringCkptPathModel)
+
+    assert cli.config.fit.model.out_dim == 3
+    assert cli.config.fit.model.hidden_dim == 6
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+    assert hparams["out_dim"] == 3
+    assert hparams["hidden_dim"] == 6
+
+    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
+    cli_args = ["predict", f"--ckpt_path={checkpoint_path}"]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = CkptPathCLI(BoringCkptPathModel)
+
+    assert cli.config.predict.model.out_dim == 3
+    assert cli.config.predict.model.hidden_dim == 6
+    assert cli.config_init.predict.model.layer.out_features == 3
+
+    err = StringIO()
+    with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stderr(err), pytest.raises(SystemExit):
+        cli = LightningCLI(BoringModel)
+    assert "Parsing of ckpt_path hyperparameters failed" in err.getvalue()
+
+
 def test_lightning_cli_submodules(cleandir):
     class MainModule(BoringModel):
         def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):