|
18 | 18 | import pytest
|
19 | 19 | import torch
|
20 | 20 | from lightning.pytorch import Trainer
|
21 |
| -from lightning.pytorch.callbacks import ModelCheckpoint |
| 21 | +from lightning.pytorch.callbacks import Callback, ModelCheckpoint |
22 | 22 | from lightning.pytorch.demos.boring_classes import BoringModel
|
23 | 23 | from lightning.pytorch.trainer.states import TrainerFn
|
24 | 24 | from lightning.pytorch.utilities.migration.utils import _set_version
|
@@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path):
|
234 | 234 | trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
|
235 | 235 | trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
|
236 | 236 | model.load_state_dict.assert_called_once_with(ANY, strict=expected)
|
| 237 | + |
| 238 | + |
@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn):
    """Test that callbacks are restored from a checkpoint in non-fit phases.

    A callback flag is set to ``True`` and written into a checkpoint. A fresh
    callback (flag ``False``) is then attached to a new ``Trainer`` and the
    parametrized entry point (``validate``/``test``/``predict``) is called with
    ``ckpt_path``. The flag must be ``True`` afterwards, which can only happen
    if the entry point itself handed the checkpoint to the callback.

    """

    class TestCallback(Callback):
        """Callback with a single boolean flag that round-trips through a checkpoint."""

        def __init__(self):
            super().__init__()
            self.restored = False

        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            # Read back the entry written by ``on_save_checkpoint`` below.
            if "callbacks" in checkpoint:
                callback_state = checkpoint["callbacks"][self.__class__.__name__]
                self.restored = callback_state["restored"]

        def state_dict(self):
            return {"restored": self.restored}

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Store the state under the bare class name, which is the key
            # ``on_load_checkpoint`` looks up.
            checkpoint.setdefault("callbacks", {})[self.__class__.__name__] = self.state_dict()

    # First create and train a model with the callback
    callback = TestCallback()
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
    trainer.fit(model)

    # Set the callback state to True before saving
    callback.restored = True
    ckpt_path = tmp_path / "checkpoint.ckpt"
    trainer.save_checkpoint(ckpt_path)

    # Now create new instances and test restoration
    new_callback = TestCallback()
    new_model = BoringModel()
    assert not new_callback.restored  # Should start False

    new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])

    # NOTE: the callbacks are deliberately NOT restored by hand here (e.g. via
    # ``new_trainer._checkpoint_connector.restore_callbacks()``). Doing so would
    # flip the flag before the entry point runs, and the final assertion would
    # then pass even if the entry point never restored any callback.

    # Run the evaluation phase (validate/test/predict)
    fn = getattr(new_trainer, trainer_fn)
    fn(new_model, ckpt_path=ckpt_path)

    assert new_callback.restored  # Should be True after loading the checkpoint
0 commit comments