@@ -364,18 +364,26 @@ def on_train_batch_end(
364
364
365
365
# For manual optimization, we save the model state that was captured in training_step
366
366
# before the optimizer step. The test case saves this state in model.saved_models.
367
- if hasattr (pl_module , "saved_models" ) and pl_module .saved_models and hasattr (pl_module , "layer" ):
368
- latest_step = max (pl_module .saved_models .keys ())
367
+ if (
368
+ hasattr (pl_module , "saved_models" )
369
+ and isinstance (pl_module .saved_models , dict )
370
+ and pl_module .saved_models
371
+ and hasattr (pl_module , "layer" )
372
+ and isinstance (pl_module .layer , torch .nn .Module )
373
+ ):
374
+ # Get the latest saved state
375
+ saved_models = pl_module .saved_models
376
+ if not saved_models : # Check if dictionary is not empty
377
+ return
378
+
379
+ latest_step = max (saved_models .keys ())
369
380
# Save the checkpoint with the pre-optimization state
370
381
with torch .no_grad ():
371
382
# Save the current state
372
- if not isinstance (pl_module .layer , torch .nn .Module ):
373
- raise TypeError ("pl_module.layer must be a torch.nn.Module for state dict operations" )
374
-
375
383
original_state = {k : v .detach ().clone () for k , v in pl_module .layer .state_dict ().items ()}
376
384
try :
377
385
# Restore the pre-optimization state
378
- saved_state = pl_module . saved_models [latest_step ]
386
+ saved_state = saved_models [latest_step ]
379
387
if not isinstance (saved_state , dict ):
380
388
raise TypeError ("Saved model state must be a dictionary" )
381
389
0 commit comments