Skip to content

Commit c110f4f

Browse files
chualanagit, Alan Chu, pre-commit-ci[bot], lantiga
authored
Allow callbacks to be restored not just during training (#20403)
* Allow callbacks to be restored not just during training
* add test case
* test test case failure
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix test case

---------

Co-authored-by: Alan Chu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <[email protected]>
1 parent cd2bd3c commit c110f4f

File tree

2 files changed

+52
-4
lines changed

2 files changed

+52
-4
lines changed

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,7 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
397397
self.resume_start(checkpoint_path)
398398
self.restore_model()
399399
self.restore_datamodule()
400-
if self.trainer.state.fn == TrainerFn.FITTING:
401-
# restore callback states
402-
self.restore_callbacks()
400+
self.restore_callbacks()
403401

404402
def dump_checkpoint(self, weights_only: bool = False) -> dict:
405403
"""Creating a model checkpoint dictionary object from various component states.

tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919
import torch
2020
from lightning.pytorch import Trainer
21-
from lightning.pytorch.callbacks import ModelCheckpoint
21+
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
2222
from lightning.pytorch.demos.boring_classes import BoringModel
2323
from lightning.pytorch.trainer.states import TrainerFn
2424
from lightning.pytorch.utilities.migration.utils import _set_version
@@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path):
234234
trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
235235
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
236236
model.load_state_dict.assert_called_once_with(ANY, strict=expected)
237+
238+
239+
@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn):
    """Test that callback state is restored from a checkpoint in non-fit phases.

    A checkpoint carrying ``restored=True`` for the callback is written after a
    short fit. A fresh trainer/callback pair then runs ``validate``, ``test`` or
    ``predict`` with ``ckpt_path`` pointing at that checkpoint, and the fresh
    callback is expected to have picked the flag up.

    Args:
        tmp_path: pytest-provided temporary directory used as the trainer root
            and as the checkpoint location.
        trainer_fn: name of the ``Trainer`` entry point to exercise.

    """

    class TestCallback(Callback):
        """Callback that round-trips a single ``restored`` flag through the checkpoint."""

        def __init__(self):
            super().__init__()
            self.restored = False

        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            # Tolerate checkpoints that have no entry for this callback instead
            # of raising ``KeyError`` on the class-name lookup.
            callback_state = checkpoint.get("callbacks", {}).get(self.__class__.__name__)
            if callback_state is not None:
                self.restored = callback_state["restored"]

        def state_dict(self):
            return {"restored": self.restored}

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Store the state under the plain class name, which is the key
            # ``on_load_checkpoint`` reads back.
            checkpoint.setdefault("callbacks", {})[self.__class__.__name__] = self.state_dict()

    # First create and train a model with the callback
    callback = TestCallback()
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
    trainer.fit(model)

    # Set the callback state to True before saving so the checkpoint differs
    # from the default state of a freshly constructed callback
    callback.restored = True
    ckpt_path = tmp_path / "checkpoint.ckpt"
    trainer.save_checkpoint(ckpt_path)

    # Now create new instances and test restoration
    new_callback = TestCallback()
    new_model = BoringModel()
    assert not new_callback.restored  # Should start False

    new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])

    # Run the evaluation phase (validate/test/predict). The callbacks are
    # deliberately NOT restored by hand beforehand: the entry point itself must
    # be what loads the checkpoint, otherwise the assertion below would hold
    # regardless of what the entry point does.
    fn = getattr(new_trainer, trainer_fn)
    fn(new_model, ckpt_path=ckpt_path)

    assert new_callback.restored  # Should be True after loading the checkpoint

0 commit comments

Comments
 (0)