
Commit 0b3322d

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 32db7d4 commit 0b3322d

2 files changed: +11 -16 lines

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 3 additions & 4 deletions

@@ -30,12 +30,9 @@
 from typing import Any, Literal, Optional, Union, cast
 from weakref import proxy
 
+import pytorch_lightning as pl
 import torch
 import yaml
-from torch import Tensor
-from typing_extensions import override
-
-import pytorch_lightning as pl
 from lightning_fabric.utilities.cloud_io import (
     _is_dir,
     _is_local_file_protocol,
@@ -50,6 +47,8 @@
     rank_zero_warn,
 )
 from pytorch_lightning.utilities.types import STEP_OUTPUT
+from torch import Tensor
+from typing_extensions import override
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
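
Both hunks only regroup imports (import pytorch_lightning as pl moves up with the other package imports; from torch import Tensor and from typing_extensions import override move below the pytorch_lightning imports), so no functional change is intended. A minimal sketch of reproducing an import re-sort locally, assuming an isort-style sorting hook; the repository's actual pre-commit configuration is not shown in this commit:

# Minimal sketch, assuming an isort-style import-sorting hook; the exact
# grouping depends on project settings (e.g. known first-party packages),
# which are not part of this commit.
import isort

before = (
    "import torch\n"
    "import yaml\n"
    "from torch import Tensor\n"
    "\n"
    "import pytorch_lightning as pl\n"
)

# isort.code() returns the source with its imports re-sorted and regrouped
# under the active configuration (library defaults here).
print(isort.code(before))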

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 8 additions & 12 deletions

@@ -28,9 +28,6 @@
 import pytest
 import torch
 import yaml
-from tests_pytorch.helpers.runif import RunIf
-from torch import optim
-
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -39,6 +36,9 @@
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
+from torch import optim
+
+from tests_pytorch.helpers.runif import RunIf
 
 if _OMEGACONF_AVAILABLE:
     from omegaconf import Container, OmegaConf
@@ -888,13 +888,11 @@ def test_default_checkpoint_behavior(tmp_path):
     assert len(results) == 1
     save_dir = tmp_path / "checkpoints"
     save_weights_only = trainer.checkpoint_callback.save_weights_only
-    save_mock.assert_has_calls(
-        [
-            call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
-            call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
-            call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
-        ]
-    )
+    save_mock.assert_has_calls([
+        call(str(save_dir / "epoch=0-step=5.ckpt"), save_weights_only),
+        call(str(save_dir / "epoch=1-step=10.ckpt"), save_weights_only),
+        call(str(save_dir / "epoch=2-step=15.ckpt"), save_weights_only),
+    ])
     ckpts = os.listdir(save_dir)
     assert len(ckpts) == 1
     assert ckpts[0] == "epoch=2-step=15.ckpt"
@@ -1478,8 +1476,6 @@ def test_save_last_versioning(tmp_path):
     assert all(not os.path.islink(tmp_path / path) for path in set(os.listdir(tmp_path)))
 
 
-
-
 def test_none_monitor_saves_correct_best_model_path(tmp_path):
     mc = ModelCheckpoint(dirpath=tmp_path, monitor=None)
     trainer = Trainer(callbacks=mc)
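
The third hunk above is purely a formatting change: assert_has_calls receives the same sequence of expected calls, only written with the list bracket hugging the call. A minimal self-contained sketch of that pattern with unittest.mock, using hypothetical checkpoint paths rather than the Lightning test itself:

# Minimal sketch of the reformatted assert_has_calls([...]) style; the
# checkpoint paths are hypothetical stand-ins, not produced by a Trainer.
from unittest.mock import MagicMock, call

save_mock = MagicMock()
save_mock("checkpoints/epoch=0-step=5.ckpt", False)
save_mock("checkpoints/epoch=1-step=10.ckpt", False)
save_mock("checkpoints/epoch=2-step=15.ckpt", False)

# Equivalent to the multi-line form it replaced: the mock must have recorded
# these calls in this order.
save_mock.assert_has_calls([
    call("checkpoints/epoch=0-step=5.ckpt", False),
    call("checkpoints/epoch=1-step=10.ckpt", False),
    call("checkpoints/epoch=2-step=15.ckpt", False),
])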
