
Commit d8cb135

add testing
1 parent b2b9efe commit d8cb135

File tree

1 file changed: +54, -0 lines changed


tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 54 additions & 0 deletions
@@ -2124,3 +2124,57 @@ def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path):

     # save_last=True should always save last.ckpt
     assert (tmp_path / "last.ckpt").exists()
+
+
+def test_save_last_only_when_checkpoint_saved(tmp_path):
+    """Test that save_last only creates last.ckpt when another checkpoint is actually saved."""
+
+    class SelectiveModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.validation_step_outputs = []
+
+        def validation_step(self, batch, batch_idx):
+            outputs = super().validation_step(batch, batch_idx)
+            epoch = self.trainer.current_epoch
+            loss = torch.tensor(1.0 - epoch * 0.1) if epoch % 2 == 0 else torch.tensor(1.0 + epoch * 0.1)
+            outputs["val_loss"] = loss
+            self.validation_step_outputs.append(outputs)
+            return outputs
+
+        def on_validation_epoch_end(self):
+            if self.validation_step_outputs:
+                avg_loss = torch.stack([x["val_loss"] for x in self.validation_step_outputs]).mean()
+                self.log("val_loss", avg_loss)
+                self.validation_step_outputs.clear()
+
+    model = SelectiveModel()
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        filename="best-{epoch}-{val_loss:.2f}",
+        monitor="val_loss",
+        save_last=True,
+        save_top_k=1,
+        mode="min",
+        every_n_epochs=1,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_progress_bar=False,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        enable_checkpointing=True,
+    )
+
+    trainer.fit(model)
+
+    checkpoint_files = list(tmp_path.glob("*.ckpt"))
+    checkpoint_names = [f.name for f in checkpoint_files]
+    assert "last.ckpt" in checkpoint_names, "last.ckpt should exist since checkpoints were saved"
+    expected_files = 2  # best checkpoint + last.ckpt
+    assert len(checkpoint_files) == expected_files, f"Expected {expected_files} files, got {len(checkpoint_files)}: {checkpoint_names}"
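
To run only the newly added test locally, a minimal sketch (assuming a checkout of the repository with the test requirements installed; the node ID is taken from the file and function names in this diff, invoked from the repository root):

    # Run just the new checkpointing test in quiet mode via pytest's Python entry point.
    import pytest

    pytest.main([
        "tests/tests_pytorch/checkpointing/test_model_checkpoint.py::test_save_last_only_when_checkpoint_saved",
        "-q",
    ])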
