
Commit 334675e

Deprecate ModelCheckpoint.save_checkpoint (#12456)
1 parent fe12bae commit 334675e

File tree: 4 files changed, +24 −7 lines

- CHANGELOG.md
- pytorch_lightning/callbacks/model_checkpoint.py
- tests/checkpointing/test_model_checkpoint.py
- tests/deprecated_api/test_remove_1-8.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -582,6 +582,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `ParallelPlugin.torch_distributed_backend` in favor of `DDPStrategy.process_group_backend` property ([#11745](https://github.com/PyTorchLightning/pytorch-lightning/pull/11745))


+- Deprecated `ModelCheckpoint.save_checkpoint` in favor of `Trainer.save_checkpoint` ([#12456](https://github.com/PyTorchLightning/pytorch-lightning/pull/12456))
+
+
 - Deprecated `Trainer.devices` in favor of `Trainer.num_devices` and `Trainer.device_ids` ([#12151](https://github.com/PyTorchLightning/pytorch-lightning/pull/12151))

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 5 additions & 2 deletions

@@ -37,7 +37,7 @@
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import _name, _version
-from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -353,7 +353,10 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
         This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
         behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
         """
-        # TODO: unused method. deprecate it
+        rank_zero_deprecation(
+            f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
+            " Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
+        )
         monitor_candidates = self._monitor_candidates(trainer)
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)
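
For callers migrating off the deprecated method, the replacement is a one-line change: save through the `Trainer`, which owns the rank-zero handling the docstring above describes. A minimal sketch of the migration; the `TinyModel`, data, and paths are illustrative, not part of this commit:

    import torch
    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint


    class TinyModel(LightningModule):
        # minimal module so the Trainer has something to checkpoint
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, batch, batch_idx):
            return self.layer(batch).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    mc = ModelCheckpoint(dirpath="checkpoints/")
    trainer = Trainer(callbacks=[mc], max_epochs=1)
    trainer.fit(TinyModel(), torch.utils.data.DataLoader(torch.randn(8, 4)))

    # Old (emits a deprecation warning as of this commit):
    #   mc.save_checkpoint(trainer)
    # New: manual checkpointing goes through the Trainer, which saves
    # only on rank 0 in distributed runs.
    trainer.save_checkpoint("checkpoints/manual.ckpt")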

tests/checkpointing/test_model_checkpoint.py

Lines changed: 6 additions & 5 deletions

@@ -22,7 +22,7 @@
 from pathlib import Path
 from typing import Union
 from unittest import mock
-from unittest.mock import call, MagicMock, Mock, patch
+from unittest.mock import call, Mock, patch

 import cloudpickle
 import pytest

@@ -834,7 +834,7 @@ def validation_epoch_end(self, outputs):
         val_check_interval=1.0,
         max_epochs=len(monitor),
     )
-    trainer.save_checkpoint = MagicMock()
+    trainer.save_checkpoint = Mock()

     trainer.fit(model)

@@ -1309,9 +1309,10 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir):
 def test_last_global_step_saved():
     # this should not save anything
     model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
-    trainer = MagicMock()
-    trainer.callback_metrics = {"foo": 123}
-    model_checkpoint.save_checkpoint(trainer)
+    trainer = Mock()
+    monitor_candidates = {"foo": 123}
+    model_checkpoint._save_topk_checkpoint(trainer, monitor_candidates)
+    model_checkpoint._save_last_checkpoint(trainer, monitor_candidates)
     assert model_checkpoint._last_global_step_saved == 0

tests/deprecated_api/test_remove_1-8.py

Lines changed: 10 additions & 0 deletions

@@ -24,6 +24,7 @@

 import pytorch_lightning
 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase, LoggerCollection
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin

@@ -1055,6 +1056,15 @@ def test_trainer_data_parallel_device_ids(monkeypatch, trainer_kwargs, expected_
     assert trainer.data_parallel_device_ids == expected_data_parallel_device_ids


+def test_deprecated_mc_save_checkpoint():
+    mc = ModelCheckpoint()
+    trainer = Trainer()
+    with mock.patch.object(trainer, "save_checkpoint"), pytest.deprecated_call(
+        match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6"
+    ):
+        mc.save_checkpoint(trainer)
+
+
 def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir):
     class TestCallbackLoadHook(Callback):
         def on_load_checkpoint(self, trainer, pl_module, callback_state):