Skip to content

Commit b4c4401

Browse files
committed
Implement test to reproduce the issue
1 parent 7d61691 commit b4c4401

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,3 +1666,30 @@ def val_dataloader(self) -> DataLoader:
16661666
trainer_kwargs["max_epochs"] = 4
16671667
trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs))
16681668
trainer.fit(model, ckpt_path=checkpoint_path)
1669+
1670+
1671+
def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):
1672+
"""Test that save_last=True when save_on_train_epoch_end=False"""
1673+
1674+
# Remove validation methods to reproduce the bug
1675+
model = BoringModel()
1676+
model.validation_step = None
1677+
model.val_dataloader = None
1678+
1679+
checkpoint_callback = ModelCheckpoint(
1680+
dirpath=tmp_path,
1681+
save_last=True,
1682+
save_on_train_epoch_end=False,
1683+
)
1684+
1685+
trainer = Trainer(
1686+
max_epochs=2,
1687+
callbacks=[checkpoint_callback],
1688+
logger=False,
1689+
enable_progress_bar=False,
1690+
)
1691+
1692+
trainer.fit(model)
1693+
1694+
# save_last=True should always save last.ckpt
1695+
assert (tmp_path / "last.ckpt").exists()

0 commit comments

Comments
 (0)