Skip to content

Commit 8588032

Browse files
Fix ModelCheckpoint tests from incomplete PR (#19205)
* Update src/lightning/pytorch/trainer/trainer.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6ea0e2d commit 8588032

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def test_default_checkpoint_freq(save_mock, tmpdir, epochs: int, val_check_inter
5454
@pytest.mark.parametrize(
5555
("k", "epochs", "val_check_interval", "expected"), [(1, 1, 1.0, 1), (2, 2, 1.0, 2), (2, 1, 0.25, 4), (2, 2, 0.3, 6)]
5656
)
57-
@pytest.mark.parametrize("save_last", [False, True])
58-
def test_top_k(save_mock, tmpdir, k: int, epochs: int, val_check_interval: float, expected: int, save_last: bool):
57+
@pytest.mark.parametrize("save_last", [False, True, "link"])
58+
def test_top_k(save_mock, tmpdir, k, epochs, val_check_interval, expected, save_last):
5959
class TestModel(BoringModel):
6060
def __init__(self):
6161
super().__init__()
@@ -79,8 +79,8 @@ def training_step(self, batch, batch_idx):
7979
)
8080
trainer.fit(model)
8181

82-
if save_last:
83-
expected = expected
82+
# save_last=True: last epochs are saved every step (so double the save calls)
83+
expected = expected * 2 if save_last is True else expected
8484
assert save_mock.call_count == expected
8585

8686

tests/tests_pytorch/models/test_restore.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,9 @@ def get_trainer_args():
311311
"best_k_models",
312312
"kth_best_model_path",
313313
"kth_value",
314+
"last_model_path",
314315
):
315316
assert getattr(before, attribute) == getattr(after, attribute), f"{attribute}"
316-
# `before.last_model_path` is a symlink pointing to a checkpoint saved before that symlink was created,
317-
# hence reloading that checkpoint will restore `after.last_model_path = ""`
318-
assert after.last_model_path == ""
319317

320318

321319
@RunIf(sklearn=True)

tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_checkpoint_plugin_called(tmpdir):
6161
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt"}
6262
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2.ckpt"
6363
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last.ckpt"
64-
assert checkpoint_plugin.save_checkpoint.call_count == 2
64+
assert checkpoint_plugin.save_checkpoint.call_count == 4
6565
assert checkpoint_plugin.remove_checkpoint.call_count == 1
6666

6767
trainer.test(model, ckpt_path=ck.last_model_path)
@@ -88,7 +88,7 @@ def test_checkpoint_plugin_called(tmpdir):
8888
assert ckpt_files == {"epoch=1-step=2.ckpt", "last.ckpt", "epoch=1-step=2-v1.ckpt", "last-v1.ckpt"}
8989
assert trainer.checkpoint_callback.best_model_path == tmpdir / "epoch=1-step=2-v1.ckpt"
9090
assert trainer.checkpoint_callback.last_model_path == tmpdir / "last-v1.ckpt"
91-
assert checkpoint_plugin.save_checkpoint.call_count == 2
91+
assert checkpoint_plugin.save_checkpoint.call_count == 4
9292
assert checkpoint_plugin.remove_checkpoint.call_count == 1
9393

9494
trainer.test(model, ckpt_path=ck.last_model_path)

0 commit comments

Comments
 (0)