Skip to content

Commit 3076ea1

Browse files
committed
add test to test prefix for checkpoint name
1 parent 8ba6381 commit 3076ea1

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,12 @@ def test_model_checkpoint_format_checkpoint_name(tmp_path, monkeypatch):
453453
ckpt_name = ckpt.format_checkpoint_name({}, ver=3)
454454
assert ckpt_name == str(tmp_path / "name-v3.ckpt")
455455

456+
# with prefix
457+
ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=tmp_path, filename="name").format_checkpoint_name(
458+
{}, prefix="test"
459+
)
460+
assert ckpt_name == str(tmp_path / "test-name.ckpt")
461+
456462
# using slashes
457463
ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=None, filename="{epoch}_{val/loss:.5f}")
458464
ckpt_name = ckpt.format_checkpoint_name({"epoch": 4, "val/loss": 0.03})

0 commit comments

Comments
 (0)