Commit db322f4

rohitgr7, ananthsub, and pre-commit-ci[bot] authored
Deprecate checkpoint_callback from the Trainer constructor in favour of enable_checkpointing (#9754)
* enable_checkpointing
* update codebase
* chlog
* update tests
* fix warning
* Apply suggestions from code review (Co-authored-by: ananthsub <[email protected]>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Apply suggestions from code review (Co-authored-by: ananthsub <[email protected]>)
* Apply suggestions from code review (Co-authored-by: ananthsub <[email protected]>)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 14fb076 commit db322f4

33 files changed, +130 -109 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -322,6 +322,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))
 
 
+- Deprecated `checkpoint_callback` from the `Trainer` constructor in favour of `enable_checkpointing` ([#9754](https://github.com/PyTorchLightning/pytorch-lightning/pull/9754))
+
+
 - Deprecated the `LightningModule.on_post_move_to_device` method ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))
 
 

docs/source/common/hyperparameters.rst

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ To recap, add ALL possible trainer flags to the argparser and init the ``Trainer
     trainer = Trainer.from_argparse_args(hparams)
 
     # or if you need to pass in callbacks
-    trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...])
+    trainer = Trainer.from_argparse_args(hparams, enable_checkpointing=..., callbacks=[...])
 
 ----------
 
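For context, the surrounding docs describe the argparse pattern; a self-contained sketch of that flow, assuming the standard `add_argparse_args`/`from_argparse_args` helpers (the explicit flag value and empty callbacks list are illustrative):

```python
from argparse import ArgumentParser

from pytorch_lightning import Trainer

# register every Trainer flag (including enable_checkpointing) on the parser
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
hparams = parser.parse_args()

# build the Trainer from the parsed flags; callbacks are still passed explicitly
trainer = Trainer.from_argparse_args(hparams, enable_checkpointing=True, callbacks=[])
```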

docs/source/common/trainer.rst

Lines changed: 35 additions & 35 deletions
@@ -528,6 +528,38 @@ Example::
 checkpoint_callback
 ^^^^^^^^^^^^^^^^^^^
 
+Deprecated: This has been deprecated in v1.5 and will be removed in v1.7. Please use ``enable_checkpointing`` instead.
+
+default_root_dir
+^^^^^^^^^^^^^^^^
+
+.. raw:: html
+
+    <video width="50%" max-width="400px" controls
+    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/default%E2%80%A8_root_dir.jpg"
+    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/default_root_dir.mp4"></video>
+
+|
+
+Default path for logs and weights when no logger or
+:class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. On
+certain clusters you might want to separate where logs and checkpoints are
+stored. If you don't then use this argument for convenience. Paths can be local
+paths or remote paths such as `s3://bucket/path` or 'hdfs://path/'. Credentials
+will need to be set up to use remote filepaths.
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(default_root_dir=os.getcwd())
+
+distributed_backend
+^^^^^^^^^^^^^^^^^^^
+Deprecated: This has been renamed ``accelerator``.
+
+enable_checkpointing
+^^^^^^^^^^^^^^^^^^^^
+
 .. raw:: html
 
     <video width="50%" max-width="400px" controls
@@ -542,11 +574,11 @@ To disable automatic checkpointing, set this to `False`.
 
 .. code-block:: python
 
-    # default used by Trainer
-    trainer = Trainer(checkpoint_callback=True)
+    # default used by Trainer, saves the most recent model to a single checkpoint after each epoch
+    trainer = Trainer(enable_checkpointing=True)
 
     # turn off automatic checkpointing
-    trainer = Trainer(checkpoint_callback=False)
+    trainer = Trainer(enable_checkpointing=False)
 
 
 You can override the default behavior by initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint`
@@ -563,38 +595,6 @@ See :doc:`Saving and Loading Weights <../common/weights_loading>` for how to cus
     # Add your callback to the callbacks list
     trainer = Trainer(callbacks=[checkpoint_callback])
 
-
-.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
-    v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.
-
-
-default_root_dir
-^^^^^^^^^^^^^^^^
-
-.. raw:: html
-
-    <video width="50%" max-width="400px" controls
-    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/default%E2%80%A8_root_dir.jpg"
-    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/default_root_dir.mp4"></video>
-
-|
-
-Default path for logs and weights when no logger or
-:class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. On
-certain clusters you might want to separate where logs and checkpoints are
-stored. If you don't then use this argument for convenience. Paths can be local
-paths or remote paths such as `s3://bucket/path` or 'hdfs://path/'. Credentials
-will need to be set up to use remote filepaths.
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(default_root_dir=os.getcwd())
-
-distributed_backend
-^^^^^^^^^^^^^^^^^^^
-Deprecated: This has been renamed ``accelerator``.
-
 fast_dev_run
 ^^^^^^^^^^^^
 
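The docs above keep the two supported patterns: customise checkpointing through a `ModelCheckpoint` in `callbacks`, or disable it and save manually. A hedged sketch of both (the monitored metric and file name are placeholders, not part of the commit):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# customise the default checkpointing behaviour via the callbacks list
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3)
trainer = Trainer(enable_checkpointing=True, callbacks=[checkpoint_callback])

# or turn automatic checkpointing off and save explicitly after training
trainer = Trainer(enable_checkpointing=False)
# trainer.fit(model)
# trainer.save_checkpoint("manual.ckpt")
```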

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 20 additions & 10 deletions
@@ -38,7 +38,8 @@ def __init__(self, trainer):
     def on_trainer_init(
         self,
         callbacks: Optional[Union[List[Callback], Callback]],
-        checkpoint_callback: bool,
+        checkpoint_callback: Optional[bool],
+        enable_checkpointing: bool,
         enable_progress_bar: bool,
         progress_bar_refresh_rate: Optional[int],
         process_position: int,
@@ -67,7 +68,7 @@ def on_trainer_init(
 
         # configure checkpoint callback
         # pass through the required args to figure out defaults
-        self._configure_checkpoint_callbacks(checkpoint_callback)
+        self._configure_checkpoint_callbacks(checkpoint_callback, enable_checkpointing)
 
         # configure swa callback
         self._configure_swa_callbacks()
@@ -140,22 +141,31 @@ def _configure_accumulated_gradients(
         self.trainer.accumulate_grad_batches = grad_accum_callback.get_accumulate_grad_batches(0)
         self.trainer.accumulation_scheduler = grad_accum_callback
 
-    def _configure_checkpoint_callbacks(self, checkpoint_callback: bool) -> None:
+    def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], enable_checkpointing: bool) -> None:
+        if checkpoint_callback is not None:
+            rank_zero_deprecation(
+                f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
+                f"be removed in v1.7. Please consider using `Trainer(enable_checkpointing={checkpoint_callback})`."
+            )
+            # if both are set then checkpoint only if both are True
+            enable_checkpointing = checkpoint_callback and enable_checkpointing
+
         # TODO: Remove this error in v1.5 so we rely purely on the type signature
-        if not isinstance(checkpoint_callback, bool):
+        if not isinstance(enable_checkpointing, bool):
             error_msg = (
-                "Invalid type provided for checkpoint_callback:"
-                f" Expected bool but received {type(checkpoint_callback)}."
+                "Invalid type provided for `enable_checkpointing`: "
+                f"Expected bool but received {type(enable_checkpointing)}."
             )
-            if isinstance(checkpoint_callback, Callback):
+            if isinstance(enable_checkpointing, Callback):
                 error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
             raise MisconfigurationException(error_msg)
-        if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
+        if self._trainer_has_checkpoint_callbacks() and enable_checkpointing is False:
             raise MisconfigurationException(
-                "Trainer was configured with checkpoint_callback=False but found ModelCheckpoint in callbacks list."
+                "Trainer was configured with `enable_checkpointing=False`"
+                " but found `ModelCheckpoint` in callbacks list."
            )
 
-        if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
+        if not self._trainer_has_checkpoint_callbacks() and enable_checkpointing is True:
             self.trainer.callbacks.append(ModelCheckpoint())
 
     def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
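For readers skimming the diff, the flag resolution boils down to a small rule; a standalone paraphrase of it (using `warnings.warn` in place of Lightning's `rank_zero_deprecation` helper, so this sketch runs on its own):

```python
from typing import Optional
import warnings


def resolve_checkpointing(checkpoint_callback: Optional[bool], enable_checkpointing: bool) -> bool:
    """Paraphrase of _configure_checkpoint_callbacks: the legacy flag only ANDs with the new one."""
    if checkpoint_callback is not None:
        warnings.warn(
            f"`checkpoint_callback={checkpoint_callback}` is deprecated in v1.5; "
            f"use `enable_checkpointing={checkpoint_callback}` instead.",
            DeprecationWarning,
        )
        # if both are set then checkpoint only if both are True
        enable_checkpointing = checkpoint_callback and enable_checkpointing
    return enable_checkpointing


assert resolve_checkpointing(None, True) is True  # default: checkpointing stays on
assert resolve_checkpointing(False, True) is False  # legacy flag can still disable it
assert resolve_checkpointing(True, False) is False  # both set: enabled only when both are True
```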

pytorch_lightning/trainer/trainer.py

Lines changed: 9 additions & 1 deletion
@@ -120,7 +120,8 @@ class Trainer(
     def __init__(
         self,
         logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
-        checkpoint_callback: bool = True,
+        checkpoint_callback: Optional[bool] = None,
+        enable_checkpointing: bool = True,
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         default_root_dir: Optional[str] = None,
         gradient_clip_val: Union[int, float] = 0.0,
@@ -215,6 +216,12 @@ def __init__(
             callbacks: Add a callback or list of callbacks.
 
             checkpoint_callback: If ``True``, enable checkpointing.
+
+                .. deprecated:: v1.5
+                    ``checkpoint_callback`` has been deprecated in v1.5 and will be removed in v1.7.
+                    Please consider using ``enable_checkpointing`` instead.
+
+            enable_checkpointing: If ``True``, enable checkpointing.
                 It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
                 :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
 
@@ -465,6 +472,7 @@ def __init__(
         self.callback_connector.on_trainer_init(
             callbacks,
             checkpoint_callback,
+            enable_checkpointing,
             enable_progress_bar,
             progress_bar_refresh_rate,
             process_position,
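A quick, test-style way to see the new constructor behaviour; this snippet is illustrative and not part of the commit (it assumes the deprecation helper raises a `DeprecationWarning` subclass, which `pytest.deprecated_call` catches):

```python
import pytest

from pytorch_lightning import Trainer


def test_checkpoint_callback_flag_is_deprecated():
    # the legacy argument still works in v1.5 but warns and points at the new name
    with pytest.deprecated_call(match="enable_checkpointing"):
        trainer = Trainer(checkpoint_callback=False)
    # resolved behaviour matches enable_checkpointing=False: no ModelCheckpoint is attached
    assert not trainer.checkpoint_callbacks
```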

tests/accelerators/test_tpu_backend.py

Lines changed: 2 additions & 4 deletions
@@ -51,7 +51,7 @@ def test_resume_training_on_cpu(tmpdir):
     """Checks if training can be resumed from a saved checkpoint on CPU."""
     # Train a model on TPU
     model = BoringModel()
-    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=8)
+    trainer = Trainer(max_epochs=1, tpu_cores=8)
     trainer.fit(model)
 
     model_path = trainer.checkpoint_callback.best_model_path
@@ -62,9 +62,7 @@ def test_resume_training_on_cpu(tmpdir):
     assert weight_tensor.device == torch.device("cpu")
 
     # Verify that training is resumed on CPU
-    trainer = Trainer(
-        resume_from_checkpoint=model_path, checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir
-    )
+    trainer = Trainer(resume_from_checkpoint=model_path, max_epochs=1, default_root_dir=tmpdir)
     trainer.fit(model)
     assert trainer.state.finished, f"Training failed with {trainer.state}"
 

tests/callbacks/test_callbacks.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ def configure_callbacks(self):
 
     model = TestModel()
     trainer_options = dict(
-        default_root_dir=tmpdir, checkpoint_callback=False, fast_dev_run=True, enable_progress_bar=False
+        default_root_dir=tmpdir, enable_checkpointing=False, fast_dev_run=True, enable_progress_bar=False
     )
 
     def assert_expected_calls(_trainer, model_callback, trainer_callback):
@@ -86,7 +86,7 @@ def configure_callbacks(self):
             return [model_callback_mock]
 
     model = TestModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, enable_checkpointing=False)
 
     callbacks_before_fit = trainer.callbacks.copy()
     assert callbacks_before_fit

tests/callbacks/test_early_stopping.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
         limit_train_batches=4,
         limit_val_batches=4,
         max_epochs=expected_count,
-        checkpoint_callback=False,
+        enable_checkpointing=False,
     )
     trainer.fit(model, datamodule=dm)
 

tests/callbacks/test_lr_monitor.py

Lines changed: 1 addition & 1 deletion
@@ -390,7 +390,7 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
         callbacks=[TestFinetuning(), lr_monitor, Check()],
         enable_progress_bar=False,
         weights_summary=None,
-        checkpoint_callback=False,
+        enable_checkpointing=False,
     )
     model = TestModel()
     model.training_epoch_end = None

tests/callbacks/test_progress_bar.py

Lines changed: 4 additions & 4 deletions
@@ -263,7 +263,7 @@ def on_validation_epoch_end(self, *args):
         limit_val_batches=limit_val_batches,
         callbacks=[progress_bar],
         logger=False,
-        checkpoint_callback=False,
+        enable_checkpointing=False,
     )
     trainer.fit(model)
 
@@ -342,7 +342,7 @@ def test_main_progress_bar_update_amount(
         limit_val_batches=val_batches,
         callbacks=[progress_bar],
         logger=False,
-        checkpoint_callback=False,
+        enable_checkpointing=False,
     )
     trainer.fit(model)
     if train_batches > 0:
@@ -362,7 +362,7 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate
         limit_test_batches=test_batches,
         callbacks=[progress_bar],
         logger=False,
-        checkpoint_callback=False,
+        enable_checkpointing=False,
    )
     trainer.test(model)
     progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])
@@ -379,7 +379,7 @@ def training_step(self, batch, batch_idx):
             return super().training_step(batch, batch_idx)
 
     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, checkpoint_callback=False
+        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
     )
     trainer.fit(TestModel())
 
