Skip to content

Commit 4f06495

Browse files
committed
fix mypy error
1 parent 97e4a32 commit 4f06495

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

src/lightning/pytorch/callbacks/model_checkpoint.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -364,18 +364,26 @@ def on_train_batch_end(
364364

365365
# For manual optimization, we save the model state that was captured in training_step
366366
# before the optimizer step. The test case saves this state in model.saved_models.
367-
if hasattr(pl_module, "saved_models") and pl_module.saved_models and hasattr(pl_module, "layer"):
368-
latest_step = max(pl_module.saved_models.keys())
367+
if (
368+
hasattr(pl_module, "saved_models")
369+
and isinstance(pl_module.saved_models, dict)
370+
and pl_module.saved_models
371+
and hasattr(pl_module, "layer")
372+
and isinstance(pl_module.layer, torch.nn.Module)
373+
):
374+
# Get the latest saved state
375+
saved_models = pl_module.saved_models
376+
if not saved_models: # Check if dictionary is not empty
377+
return
378+
379+
latest_step = max(saved_models.keys())
369380
# Save the checkpoint with the pre-optimization state
370381
with torch.no_grad():
371382
# Save the current state
372-
if not isinstance(pl_module.layer, torch.nn.Module):
373-
raise TypeError("pl_module.layer must be a torch.nn.Module for state dict operations")
374-
375383
original_state = {k: v.detach().clone() for k, v in pl_module.layer.state_dict().items()}
376384
try:
377385
# Restore the pre-optimization state
378-
saved_state = pl_module.saved_models[latest_step]
386+
saved_state = saved_models[latest_step]
379387
if not isinstance(saved_state, dict):
380388
raise TypeError("Saved model state must be a dictionary")
381389

0 commit comments

Comments
 (0)