Skip to content

Commit 78d4037

Browse files
Update tests/tests_pytorch/callbacks/test_weight_averaging.py
Thanks @GdoongMathew! Co-authored-by: GdoongMathew <[email protected]>
1 parent b982d37 commit 78d4037

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/tests_pytorch/callbacks/test_weight_averaging.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,9 +426,9 @@ def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
426426
# Resume from checkpoint
427427
model2 = TestModel()
428428
callback2 = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
429-
checkpoint_path = str(tmp_path / "lightning_logs" / "version_0" / "checkpoints" / "*.ckpt")
430-
431-
_train(model2, dataset, tmp_path, callback2, checkpoint_path=checkpoint_path)
429+
import glob # should be at the top
430+
_train(model2, dataset, tmp_path, callback2,
431+
checkpoint_path=glob.glob((tmp_path / "checkpoints" / "*.ckpt").as_posix())[0])
432432

433433
assert callback2._average_model is not None
434434

0 commit comments

Comments
 (0)