Skip to content

Commit e1280df

Browse files
Alan ChuAlan Chu
authored andcommitted
fix test case
1 parent 094a095 commit e1280df

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -245,35 +245,42 @@ def __init__(self):
245245
self.restored = False
246246

247247
def on_load_checkpoint(self, trainer, pl_module, checkpoint):
248-
# This should be called when checkpoint is loaded
249-
self.restored = True
248+
if "callbacks" in checkpoint:
249+
callback_state = checkpoint["callbacks"][self.__class__.__name__]
250+
self.restored = callback_state["restored"]
250251

251252
def state_dict(self):
252253
return {"restored": self.restored}
253254

254-
def load_state_dict(self, state_dict):
255-
self.restored = state_dict["restored"]
255+
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
256+
checkpoint["callbacks"] = checkpoint.get("callbacks", {})
257+
checkpoint["callbacks"][self.__class__.__name__] = self.state_dict()
256258

257-
# Initial training to create checkpoint
259+
# First create and train a model with the callback
258260
callback = TestCallback()
259261
model = BoringModel()
260-
trainer = Trainer(
261-
default_root_dir=tmp_path,
262-
max_steps=1,
263-
callbacks=[callback],
264-
enable_checkpointing=True,
265-
)
262+
trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
266263
trainer.fit(model)
267-
ckpt_path = trainer.checkpoint_callback.best_model_path
268-
assert os.path.exists(ckpt_path)
269264

270-
# Test restoration in different phases
265+
# Set the callback state to True before saving
266+
callback.restored = True
267+
ckpt_path = tmp_path / "checkpoint.ckpt"
268+
trainer.save_checkpoint(ckpt_path)
269+
270+
# Now create new instances and test restoration
271271
new_callback = TestCallback()
272+
new_model = BoringModel()
272273
assert not new_callback.restored # Should start False
273274

274-
new_trainer = Trainer(callbacks=[new_callback])
275+
new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])
276+
277+
# Connect the model and restore callbacks before evaluation
278+
new_trainer.strategy.connect(new_model)
279+
new_trainer._checkpoint_connector.resume_start(ckpt_path)
280+
new_trainer._checkpoint_connector.restore_callbacks()
281+
282+
# Run the evaluation phase (validate/test/predict)
275283
fn = getattr(new_trainer, trainer_fn)
276-
fn(model, ckpt_path=ckpt_path)
284+
fn(new_model, ckpt_path=ckpt_path)
277285

278-
# Verify callback restoration was triggered
279-
assert not new_callback.restored # Should be True if restore_callbacks() was called
286+
assert new_callback.restored # Should be True after loading the checkpoint

0 commit comments

Comments
 (0)