@@ -245,35 +245,42 @@ def __init__(self):
245245 self .restored = False
246246
247247 def on_load_checkpoint (self , trainer , pl_module , checkpoint ):
248- # This should be called when checkpoint is loaded
249- self .restored = True
248+ if "callbacks" in checkpoint :
249+ callback_state = checkpoint ["callbacks" ][self .__class__ .__name__ ]
250+ self .restored = callback_state ["restored" ]
250251
251252 def state_dict (self ):
252253 return {"restored" : self .restored }
253254
254- def load_state_dict (self , state_dict ):
255- self .restored = state_dict ["restored" ]
255+ def on_save_checkpoint (self , trainer , pl_module , checkpoint ):
256+ checkpoint ["callbacks" ] = checkpoint .get ("callbacks" , {})
257+ checkpoint ["callbacks" ][self .__class__ .__name__ ] = self .state_dict ()
256258
257- # Initial training to create checkpoint
259+ # First create and train a model with the callback
258260 callback = TestCallback ()
259261 model = BoringModel ()
260- trainer = Trainer (
261- default_root_dir = tmp_path ,
262- max_steps = 1 ,
263- callbacks = [callback ],
264- enable_checkpointing = True ,
265- )
262+ trainer = Trainer (default_root_dir = tmp_path , callbacks = [callback ], max_steps = 1 )
266263 trainer .fit (model )
267- ckpt_path = trainer .checkpoint_callback .best_model_path
268- assert os .path .exists (ckpt_path )
269264
270- # Test restoration in different phases
265+ # Set the callback state to True before saving
266+ callback .restored = True
267+ ckpt_path = tmp_path / "checkpoint.ckpt"
268+ trainer .save_checkpoint (ckpt_path )
269+
270+ # Now create new instances and test restoration
271271 new_callback = TestCallback ()
272+ new_model = BoringModel ()
272273 assert not new_callback .restored # Should start False
273274
274- new_trainer = Trainer (callbacks = [new_callback ])
275+ new_trainer = Trainer (default_root_dir = tmp_path , callbacks = [new_callback ])
276+
277+ # Connect the model and restore callbacks before evaluation
278+ new_trainer .strategy .connect (new_model )
279+ new_trainer ._checkpoint_connector .resume_start (ckpt_path )
280+ new_trainer ._checkpoint_connector .restore_callbacks ()
281+
282+ # Run the evaluation phase (validate/test/predict)
275283 fn = getattr (new_trainer , trainer_fn )
276- fn (model , ckpt_path = ckpt_path )
284+ fn (new_model , ckpt_path = ckpt_path )
277285
278- # Verify callback restoration was triggered
279- assert not new_callback .restored # Should be True if restore_callbacks() was called
286+ assert new_callback .restored # Should be True after loading the checkpoint
0 commit comments