
Commit 744cbbf

Merge branch 'master' into feat/ModelCheckpointException
2 parents: 0336478 + 6a09f27

File tree: 16 files changed, +175 -61 lines


.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@ repos:
         #args: ["--write-changes"] # uncomment if you want to get automatic fixing
 
   - repo: https://github.com/PyCQA/docformatter
-    rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5
+    rev: v1.7.7
     hooks:
       - id: docformatter
         additional_dependencies: [tomli]
@@ -70,7 +70,7 @@ repos:
       - id: sphinx-lint
 
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.4
+    rev: v0.12.2
     hooks:
       # try to fix what is possible
       - id: ruff

docs/source-pytorch/common/trainer.rst

Lines changed: 7 additions & 4 deletions
@@ -759,6 +759,9 @@ overfit_batches
 Uses this much data of the training & validation set.
 If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it.
 
+* When set to a value > 0, sequential sampling (no shuffling) is used
+* Consistent batches are used for both training and validation across epochs, but training and validation use different sets of data
+
 Useful for quickly debugging or trying to overfit on purpose.
 
 .. testcode::
@@ -769,11 +772,11 @@ Useful for quickly debugging or trying to overfit on purpose.
 
     # use only 1% of the train & val set
     trainer = Trainer(overfit_batches=0.01)
 
-    # overfit on 10 of the same batches
+    # overfit on 10 consistent train batches & 10 consistent val batches
     trainer = Trainer(overfit_batches=10)
 
-plugins
-^^^^^^^
+    # debug using a single consistent train batch and a single consistent val batch
+
 
 :ref:`Plugins` allow you to connect arbitrary backends, precision libraries, clusters etc. For example:
 
@@ -895,7 +898,7 @@ DataSource can be a ``LightningModule`` or a ``LightningDataModule``.
 
     # if 0 (default)
     train_loader = model.train_dataloader()
-    # or if using data module: datamodule.train_dataloader()
+    # or if using data module: datamodule.train_dataloaders()
     for epoch in epochs:
         for batch in train_loader:
             ...
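For context, a minimal usage sketch (not part of the diff) showing how the documented `overfit_batches` values are passed to the Trainer; `BoringModel` is just a stand-in module here:

# illustrative only; mirrors the flag values documented above
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()

# use only 1% of the train & val set (sampled sequentially, shuffling disabled)
trainer = Trainer(overfit_batches=0.01, max_epochs=2)

# or: 10 consistent train batches & 10 consistent val batches across epochs
# trainer = Trainer(overfit_batches=10, max_epochs=2)

trainer.fit(model)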

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -39,6 +39,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed metrics in `RichProgressBar` being updated according to user provided `refresh_rate` ([#21032](https://github.com/Lightning-AI/pytorch-lightning/pull/21032))
 
+
+- Fix `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960))
+
+
 ---
 
 ## [2.5.2] - 2025-06-20

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 8 additions & 0 deletions
@@ -348,6 +348,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
         self._save_last_checkpoint(trainer, monitor_candidates)
 
     @override
+
     def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
         """Save a checkpoint when an exception is raised."""
         if not self._should_save_on_exception(trainer):
@@ -361,6 +362,13 @@ def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", e
             {str(exception)}, saved checkpoint to {filepath}"
         )
 
+    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """Ensure save_last=True is applied when training ends."""
+        if self.save_last and not self._last_checkpoint_saved:
+            monitor_candidates = self._monitor_candidates(trainer)
+            self._save_last_checkpoint(trainer, monitor_candidates)
+
+
     @override
     def state_dict(self) -> dict[str, Any]:
         return {
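To illustrate what the new `on_train_end` hook guarantees, here is a rough sketch (assuming the `BoringModel` demo class; it mirrors the test added later in this commit):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.validation_step = None  # no validation loop, so on_validation_end never fires
model.val_dataloader = None

ckpt = ModelCheckpoint(dirpath="checkpoints/", save_last=True, save_on_train_epoch_end=False)
trainer = Trainer(max_epochs=1, callbacks=[ckpt], logger=False, enable_progress_bar=False)
trainer.fit(model)

# with the hook above, a "last" checkpoint is still written when training ends
print(ckpt.last_model_path)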

src/lightning/pytorch/core/optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]
         scheduler["reduce_on_plateau"] = scheduler.get(
             "reduce_on_plateau", isinstance(scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau)
         )
-        if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
+        if scheduler["reduce_on_plateau"] and scheduler.get("monitor") is None:
             raise MisconfigurationException(
                 "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
                 ' For example: {"optimizer": optimizer, "lr_scheduler":'
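For reference, a hedged sketch of the configuration this check validates: a `ReduceLROnPlateau` scheduler dict returned from `configure_optimizers` must carry a `monitor` key (the metric name `val_loss` is only a placeholder):

from lightning.pytorch import LightningModule
from torch import optim


class PlateauModule(LightningModule):  # hypothetical module that logs "val_loss"
    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        # omitting "monitor" here would raise the MisconfigurationException checked above
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }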

src/lightning/pytorch/demos/transformer.py

Lines changed: 13 additions & 4 deletions
@@ -54,15 +54,24 @@ def __init__(
 
         self.ninp = ninp
         self.vocab_size = vocab_size
-        self.src_mask = None
+        self.src_mask: Optional[Tensor] = None
+
+    def generate_square_subsequent_mask(self, size: int) -> Tensor:
+        """Generate a square mask for the sequence to prevent future tokens from being seen."""
+        mask = torch.triu(torch.ones(size, size), diagonal=1)
+        mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)
+        return mask
 
     def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         _, t = inputs.shape
 
-        # we assume target is already shifted w.r.t. inputs
+        # Generate source mask to prevent future token leakage
+        if self.src_mask is None or self.src_mask.size(0) != t:
+            self.src_mask = self.generate_square_subsequent_mask(t).to(inputs.device)
+
+        # Generate target mask if not provided
         if mask is None:
-            mask = torch.tril(torch.ones(t, t, device=inputs.device)) == 1
-            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
+            mask = self.generate_square_subsequent_mask(t).to(inputs.device)
 
         src = self.pos_encoder(self.embedding(inputs) * math.sqrt(self.ninp))
         target = self.pos_encoder(self.embedding(target) * math.sqrt(self.ninp))
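The new mask helper can be exercised on its own; a standalone sketch of the same `torch.triu`-based construction (positions above the diagonal become `-inf`, so attention cannot look at future tokens):

import torch


def generate_square_subsequent_mask(size: int) -> torch.Tensor:
    # upper triangle above the diagonal marks "future" positions
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    return mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0)


print(generate_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])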

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 8 additions & 0 deletions
@@ -244,15 +244,23 @@ def _get_distributed_sampler(
 
 
 def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
+    """Resolve overfit batches by disabling shuffling.
+
+    When overfit_batches > 0, this function ensures that sequential sampling is used without shuffling for consistent
+    batches across epochs. Training and validation use different sets of data.
+
+    """
     all_have_sequential_sampler = all(
         isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler")
     )
     if all_have_sequential_sampler:
         return
+
     rank_zero_warn(
         f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
         f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
     )
+
     updated = [
         _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl
         for dl in combined_loader.flattened
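A small illustrative sketch (an assumption about usage, not the connector's actual call path) of the effect described in the new docstring: a shuffling dataloader is rebuilt around a `SequentialSampler`, so every epoch yields the same batches in the same order:

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
shuffled = DataLoader(dataset, batch_size=2, shuffle=True)  # what the user configured
sequential = DataLoader(dataset, batch_size=2, sampler=SequentialSampler(dataset))  # what overfitting enforces

print([batch[0].tolist() for batch in sequential])
# [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]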

tests/tests_pytorch/callbacks/test_finetuning_callback.py

Lines changed: 1 addition & 1 deletion
@@ -109,8 +109,8 @@ def configure_optimizers(self):
     model.validation_step = None
     callback = TestBackboneFinetuningWarningCallback(unfreeze_backbone_at_epoch=3, verbose=False)
 
+    trainer = Trainer(limit_train_batches=1, default_root_dir=tmp_path, callbacks=[callback, chk], max_epochs=2)
     with pytest.warns(UserWarning, match="Did you init your optimizer in"):
-        trainer = Trainer(limit_train_batches=1, default_root_dir=tmp_path, callbacks=[callback, chk], max_epochs=2)
         trainer.fit(model)
 
     assert model.backbone.has_been_used

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 27 additions & 0 deletions
@@ -2086,3 +2086,30 @@ def val_dataloader(self) -> DataLoader:
     trainer_kwargs["max_epochs"] = 4
     trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
     trainer.fit(model, ckpt_path=checkpoint_path)
+
+
+def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
+    """Test that save_last=True works correctly when save_on_train_epoch_end=False in a model without validation."""
+
+    # Remove validation methods to test the edge case
+    model = BoringModel()
+    model.validation_step = None
+    model.val_dataloader = None
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_last=True,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model)
+
+    # save_last=True should always save last.ckpt
+    assert (tmp_path / "last.ckpt").exists()

tests/tests_pytorch/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -95,6 +95,12 @@ def restore_env_variables():
         "TF_GRPC_DEFAULT_OPTIONS",
         "XLA_FLAGS",
         "TORCHINDUCTOR_CACHE_DIR",  # leaked by torch.compile
+        # TensorFlow and TPU related variables
+        "TF2_BEHAVIOR",
+        "TPU_ML_PLATFORM",
+        "TPU_ML_PLATFORM_VERSION",
+        "LD_LIBRARY_PATH",
+        "ENABLE_RUNTIME_UPTIME_TELEMETRY",
     }
     leaked_vars.difference_update(allowlist)
     assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
