|
28 | 28 | import pytest |
29 | 29 | import torch |
30 | 30 | import yaml |
31 | | -from tests_pytorch.helpers.runif import RunIf |
32 | | -from torch import optim |
33 | | - |
34 | 31 | from lightning.fabric.utilities.cloud_io import _load as pl_load |
35 | 32 | from lightning.pytorch import Trainer, seed_everything |
36 | 33 | from lightning.pytorch.callbacks import ModelCheckpoint |
|
39 | 36 | from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger |
40 | 37 | from lightning.pytorch.utilities.exceptions import MisconfigurationException |
41 | 38 | from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE |
| 39 | +from torch import optim |
| 40 | + |
| 41 | +from tests_pytorch.helpers.runif import RunIf |
42 | 42 |
|
43 | 43 | if _OMEGACONF_AVAILABLE: |
44 | 44 | from omegaconf import Container, OmegaConf |
@@ -888,13 +888,11 @@ def test_default_checkpoint_behavior(tmp_path): |
888 | 888 | assert len(results) == 1 |
889 | 889 | save_dir = tmp_path / "checkpoints" |
890 | 890 | save_weights_only = trainer.checkpoint_callback.save_weights_only |
891 | | - save_mock.assert_has_calls( |
892 | | - [ |
893 | | - call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only), |
894 | | - call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only), |
895 | | - call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only), |
896 | | - ] |
897 | | - ) |
| 891 | + save_mock.assert_has_calls([ |
| 892 | + call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only), |
| 893 | + call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only), |
| 894 | + call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only), |
| 895 | + ]) |
898 | 896 | ckpts = os.listdir(save_dir) |
899 | 897 | assert len(ckpts) == 1 |
900 | 898 | assert ckpts[0] == "epoch=2-step=15.ckpt" |
@@ -1478,8 +1476,6 @@ def test_save_last_versioning(tmp_path): |
1478 | 1476 | assert all(not os.path.islink(tmp_path / path) for path in set(os.listdir(tmp_path))) |
1479 | 1477 |
|
1480 | 1478 |
|
1481 | | - |
1482 | | - |
1483 | 1479 | def test_none_monitor_saves_correct_best_model_path(tmp_path): |
1484 | 1480 | mc = ModelCheckpoint(dirpath=tmp_path, monitor=None) |
1485 | 1481 | trainer = Trainer(callbacks=mc) |
|
0 commit comments