Skip to content

Commit a43eb5d

Browse files
Alan ChuAlan Chu
authored andcommitted
add test case
1 parent 1f80567 commit a43eb5d

File tree

1 file changed

+44
-1
lines changed

1 file changed

+44
-1
lines changed

tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919
import torch
2020
from lightning.pytorch import Trainer
21-
from lightning.pytorch.callbacks import ModelCheckpoint
21+
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
2222
from lightning.pytorch.demos.boring_classes import BoringModel
2323
from lightning.pytorch.trainer.states import TrainerFn
2424
from lightning.pytorch.utilities.migration.utils import _set_version
@@ -234,3 +234,46 @@ def test_strict_loading(strict_loading, expected, tmp_path):
234234
trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
235235
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
236236
model.load_state_dict.assert_called_once_with(ANY, strict=expected)
237+
238+
239+
@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
240+
def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn):
241+
"""Test that callbacks are properly restored in non-fit phases."""
242+
243+
class TestCallback(Callback):
244+
def __init__(self):
245+
self.restored = False
246+
247+
def on_load_checkpoint(self, trainer, pl_module, checkpoint):
248+
# This should be called when checkpoint is loaded
249+
self.restored = True
250+
251+
def state_dict(self):
252+
return {"restored": self.restored}
253+
254+
def load_state_dict(self, state_dict):
255+
self.restored = state_dict["restored"]
256+
257+
# Initial training to create checkpoint
258+
callback = TestCallback()
259+
model = BoringModel()
260+
trainer = Trainer(
261+
default_root_dir=tmp_path,
262+
max_steps=1,
263+
callbacks=[callback],
264+
enable_checkpointing=True,
265+
)
266+
trainer.fit(model)
267+
ckpt_path = trainer.checkpoint_callback.best_model_path
268+
assert os.path.exists(ckpt_path)
269+
270+
# Test restoration in different phases
271+
new_callback = TestCallback()
272+
assert not new_callback.restored # Should start False
273+
274+
new_trainer = Trainer(callbacks=[new_callback])
275+
fn = getattr(new_trainer, trainer_fn)
276+
fn(model, ckpt_path=ckpt_path)
277+
278+
# Verify callback restoration was triggered
279+
assert new_callback.restored # Should be True if restore_callbacks() was called

0 commit comments

Comments
 (0)