
Commit f2b0db6

Authored by rohitgr7, ananthsub, and tchaton
Raise a MisconfigurationException when trainer functions are called with ckpt_path="best" but checkpoint_callback isn't configured (#9841)

* add check
* chlog
* Apply suggestions from code review (Co-authored-by: ananthsub <[email protected]>)
* Apply suggestions from code review (Co-authored-by: thomas chaton <[email protected]>)

Co-authored-by: ananthsub <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
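In short: calling trainer.validate, trainer.test, or trainer.predict with ckpt_path="best" now fails fast with a clear error when no ModelCheckpoint callback is active, instead of hitting an unhelpful attribute lookup. A minimal sketch of the new behavior (the BoringModel import path is the test-suite helper location and is an assumption here):

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException
    from tests.helpers import BoringModel  # assumed import path for the test helper

    model = BoringModel()
    # checkpoint_callback=False disables the default ModelCheckpoint,
    # so no "best" checkpoint is ever tracked during fit.
    trainer = Trainer(max_epochs=1, checkpoint_callback=False)
    trainer.fit(model)

    try:
        trainer.validate(ckpt_path="best")
    except MisconfigurationException as err:
        # `.validate(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.
        print(err)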
Parent commit: 64d1c46

File tree: 3 files changed, +62 −7 lines

* CHANGELOG.md
* pytorch_lightning/trainer/trainer.py
* tests/trainer/test_trainer.py

CHANGELOG.md (3 additions, 0 deletions)

@@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enabled automatic parameters tying for TPUs ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


+- Raise a `MisconfigurationException` when trainer functions are called with `ckpt_path="best"` but `checkpoint_callback` isn't configured ([#9841](https://github.com/PyTorchLightning/pytorch-lightning/pull/9841))
+
+
 - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))

pytorch_lightning/trainer/trainer.py (15 additions, 7 deletions)

@@ -669,7 +669,8 @@ def validate(

             ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
                 If ``None`` and the model instance was passed, use the current weights.
-                Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
+                Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
+                if a checkpoint callback is configured.

             verbose: If True, prints the validation results.

@@ -758,7 +759,8 @@ def test(

             ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
                 If ``None`` and the model instance was passed, use the current weights.
-                Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
+                Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
+                if a checkpoint callback is configured.

             verbose: If True, prints the test results.

@@ -852,7 +854,8 @@ def predict(

             ckpt_path: Either ``best`` or path to the checkpoint you wish to predict.
                 If ``None`` and the model instance was passed, use the current weights.
-                Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
+                Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
+                if a checkpoint callback is configured.

         Returns:
             Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.

@@ -1281,15 +1284,20 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_

         if model_connected and ckpt_path is None:
             rank_zero_warn(
-                f"`.{fn}(ckpt_path=None)` was called without a model. "
-                "The best model of the previous `fit` call will be used. "
-                f"You can pass `{fn}(ckpt_path='best')` to avoid this warning "
-                "or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
+                f"`.{fn}(ckpt_path=None)` was called without a model."
+                " The best model of the previous `fit` call will be used."
+                f" You can pass `{fn}(ckpt_path='best')` to use the best model"
+                " checkpoint and avoid this warning or"
+                " `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
             )
             ckpt_path = "best"

         if ckpt_path == "best":
             # if user requests the best checkpoint but we don't have it, error
+            if not self.checkpoint_callback:
+                raise MisconfigurationException(
+                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
+                )
             if not self.checkpoint_callback.best_model_path:
                 if self.fast_dev_run:
                     raise MisconfigurationException(
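Taken together, the updated docstrings describe three ways to pick weights for evaluation. A short usage sketch, reusing the fitted model from the sketch above (the checkpoint path in the last line is a placeholder):

    from pytorch_lightning import Trainer

    trainer = Trainer(max_epochs=1)  # the default ModelCheckpoint callback is enabled
    trainer.fit(model)  # `model`: the BoringModel instance from the sketch above

    trainer.test(model, ckpt_path=None)  # use the model's current in-memory weights
    trainer.test(ckpt_path="best")       # load the best checkpoint tracked during `fit`
    trainer.test(ckpt_path="path/to/some.ckpt")  # load an explicit checkpoint (placeholder path)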

tests/trainer/test_trainer.py (44 additions, 0 deletions)

@@ -785,6 +785,50 @@ def predict_step(self, batch, *_):
     assert getattr(trainer, path_attr) == ckpt_path


+@pytest.mark.parametrize("checkpoint_callback", (False, True))
+@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
+def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
+    class TestModel(BoringModel):
+        def validation_step(self, batch, batch_idx):
+            self.log("foo", -batch_idx)
+            return super().validation_step(batch, batch_idx)
+
+        def test_step(self, *args):
+            return self.validation_step(*args)
+
+        def predict_step(self, batch, *_):
+            return self(batch)
+
+    model = TestModel()
+    model.test_epoch_end = None
+    trainer = Trainer(
+        max_epochs=2,
+        limit_val_batches=1,
+        limit_test_batches=1,
+        limit_predict_batches=1,
+        enable_progress_bar=False,
+        default_root_dir=tmpdir,
+        checkpoint_callback=checkpoint_callback,
+    )
+    trainer.fit(model)
+
+    trainer_fn = getattr(trainer, fn)
+    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
+    assert getattr(trainer, path_attr) is None
+
+    if checkpoint_callback:
+        trainer_fn(ckpt_path="best")
+        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+
+        trainer_fn(model, ckpt_path="best")
+        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
+    else:
+        with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
+            trainer_fn(ckpt_path="best")
+        with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
+            trainer_fn(model, ckpt_path="best")
+
+
 def test_disabled_training(tmpdir):
     """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""
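The parametrized test exercises all three trainer entry points with the checkpoint callback both on and off; the `path_attr` expression resolves to `validated_ckpt_path`, `tested_ckpt_path`, or `predicted_ckpt_path` depending on `fn`. To run just this test locally, pytest's standard `-k` filter should do it:

    pytest tests/trainer/test_trainer.py -k test_tested_checkpoint_path_best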
