Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))
- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))


- Fixed recursive symlink creation when `save_last='link'` and `save_top_k=-1` ([#21186](https://github.com/Lightning-AI/pytorch-lightning/pull/21186))


### Removed
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:

@staticmethod
def _link_checkpoint(trainer: "pl.Trainer", filepath: str, linkpath: str) -> None:
if trainer.is_global_zero:
if trainer.is_global_zero and os.path.abspath(filepath) != os.path.abspath(linkpath):
if os.path.islink(linkpath) or os.path.isfile(linkpath):
os.remove(linkpath)
elif os.path.isdir(linkpath):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import os
from datetime import timedelta

import pytest
Expand All @@ -9,6 +10,7 @@

from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel


class TinyDataset(Dataset):
Expand Down Expand Up @@ -206,3 +208,24 @@ def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tm
expected = max(val_scores) # last/maximum value occurs at final validation epoch
actual = float(ckpt.best_model_score)
assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6)


def test_model_checkpoint_save_last_link_symlink_bug(tmp_path):
    """Check that ``save_last='link'`` with ``save_top_k=-1`` does not leave ``last.ckpt`` linked to itself.

    Fits a ``BoringModel`` for two epochs with a ``ModelCheckpoint`` that writes into ``tmp_path``, then inspects
    ``tmp_path / "last.ckpt"``: it must exist and, if it is a symlink, its resolved target must differ from the
    link's own location.

    """
    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, every_n_epochs=10, save_last="link", save_top_k=-1)
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=2,
        callbacks=[checkpoint_callback],
        enable_checkpointing=True,
        enable_model_summary=False,
        logger=False,
    )

    trainer.fit(BoringModel())

    last_ckpt = tmp_path / "last.ckpt"
    # `Path.exists` follows symlinks, so a link whose target cannot be resolved fails here.
    assert last_ckpt.exists()

    link_path = str(last_ckpt)
    if os.path.islink(link_path):
        # `os.readlink` may return a target relative to the link's directory. Resolve it before comparing,
        # otherwise a relative self-reference (e.g. "last.ckpt") would never equal the absolute link path
        # and the recursion check would pass vacuously.
        target = os.path.join(os.path.dirname(link_path), os.readlink(link_path))
        assert os.path.abspath(target) != os.path.abspath(link_path)
Loading