Skip to content

Commit 812ffde

Browse files
authored
Fix save_last type annotation for ModelCheckpoint (#19808)
1 parent 7668a6b commit 812ffde

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5757

5858
- Fixed `WandbLogger.log_hyperparameters()` raising an error if hyperparameters are not JSON serializable ([#19769](https://github.com/Lightning-AI/pytorch-lightning/pull/19769))
5959

60-
-
60+
61+
- Fixed an issue with the LightningCLI not being able to set the `ModelCheckpoint(save_last=...)` argument ([#19808](https://github.com/Lightning-AI/pytorch-lightning/pull/19808))
6162

6263

6364
## [2.2.2] - 2024-04-11

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from copy import deepcopy
2828
from datetime import timedelta
2929
from pathlib import Path
30-
from typing import Any, Dict, Literal, Optional, Set
30+
from typing import Any, Dict, Literal, Optional, Set, Union
3131
from weakref import proxy
3232

3333
import torch
@@ -216,7 +216,7 @@ def __init__(
216216
filename: Optional[str] = None,
217217
monitor: Optional[str] = None,
218218
verbose: bool = False,
219-
save_last: Optional[Literal[True, False, "link"]] = None,
219+
save_last: Optional[Union[bool, Literal["link"]]] = None,
220220
save_top_k: int = 1,
221221
save_weights_only: bool = False,
222222
mode: str = "min",

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import time
1919
from argparse import Namespace
2020
from datetime import timedelta
21+
from inspect import signature
2122
from pathlib import Path
2223
from typing import Union
2324
from unittest import mock
@@ -28,6 +29,7 @@
2829
import pytest
2930
import torch
3031
import yaml
32+
from jsonargparse import ArgumentParser
3133
from lightning.fabric.utilities.cloud_io import _load as pl_load
3234
from lightning.pytorch import Trainer, seed_everything
3335
from lightning.pytorch.callbacks import ModelCheckpoint
@@ -1601,3 +1603,24 @@ def test_expand_home():
16011603
# it is possible to have a folder with the name `~`
16021604
checkpoint = ModelCheckpoint(dirpath="./~/checkpoints")
16031605
assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints")
1606+
1607+
1608+
@pytest.mark.parametrize(
1609+
("val", "expected"),
1610+
[
1611+
("yes", True),
1612+
("True", True),
1613+
("true", True),
1614+
("no", False),
1615+
("false", False),
1616+
("False", False),
1617+
("link", "link"),
1618+
],
1619+
)
1620+
def test_save_last_cli(val, expected):
1621+
"""Test that the CLI can parse the `save_last` argument correctly (composed type)."""
1622+
annot = signature(ModelCheckpoint).parameters["save_last"].annotation
1623+
parser = ArgumentParser()
1624+
parser.add_argument("--a", type=annot)
1625+
args = parser.parse_args(["--a", val])
1626+
assert args.a == expected

0 commit comments

Comments
 (0)