@@ -730,6 +730,71 @@ def test_restart_parity(tmp_path):
     assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {}
 
 
+def test_restart_parity_with_val(tmp_path):
+    model = PredictableBoringModel()
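+    # save a checkpoint every 2 training steps and keep them all (save_top_k=-1)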
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        every_n_train_steps=2,
+        save_top_k=-1,
+    )
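+    # first run: train 4 epochs straight through, validating every 2 batches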
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=4,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        limit_val_batches=4,
+        val_check_interval=2,
+    )
+    trainer.fit(model)
+    loss = model.last_loss
+
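+    # second run: identical configuration, but resumed from the step-2 checkpoint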
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=4,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        limit_val_batches=4,
+        val_check_interval=2,
+    )
+    trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt"))
+    loss_v1 = model.last_loss
+
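+    # the resumed run must reproduce the final loss of the uninterrupted run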
+    assert abs(loss - loss_v1) < 1e-8
+
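+    # the resumed run writes "-v1" checkpoints alongside the originals;
+    # the end-of-epoch state of both runs must match exactly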
+    end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=0-step=4.ckpt"), weights_only=True)
+    end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=0-step=4-v1.ckpt"), weights_only=True)
+
+    assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {}
+    assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {}
+    assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"]
+    assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"]
+    assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {}
+
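+    # mid-epoch checkpoints (written between validation runs) must also agree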
+    mid_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=6.ckpt"), weights_only=True)
+    mid_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=6-v1.ckpt"), weights_only=True)
+
+    assert compare_state_dicts(mid_epoch_ckpt["loops"], mid_epoch_ckpt_v1["loops"]) == {}
+    assert compare_state_dicts(mid_epoch_ckpt["lr_schedulers"][0], mid_epoch_ckpt_v1["lr_schedulers"][0]) == {}
+    assert mid_epoch_ckpt["epoch"] == mid_epoch_ckpt_v1["epoch"]
+    assert mid_epoch_ckpt["global_step"] == mid_epoch_ckpt_v1["global_step"]
+    assert compare_state_dicts(mid_epoch_ckpt["state_dict"], mid_epoch_ckpt_v1["state_dict"]) == {}
+
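+    # same check for the end-of-epoch checkpoints of the second epoch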
+    end_of_epoch_ckpt = torch.load(str(tmp_path / "epoch=1-step=8.ckpt"), weights_only=True)
+    end_of_epoch_ckpt_v1 = torch.load(str(tmp_path / "epoch=1-step=8-v1.ckpt"), weights_only=True)
+
+    assert compare_state_dicts(end_of_epoch_ckpt["loops"], end_of_epoch_ckpt_v1["loops"]) == {}
+    assert compare_state_dicts(end_of_epoch_ckpt["lr_schedulers"][0], end_of_epoch_ckpt_v1["lr_schedulers"][0]) == {}
+    assert end_of_epoch_ckpt["epoch"] == end_of_epoch_ckpt_v1["epoch"]
+    assert end_of_epoch_ckpt["global_step"] == end_of_epoch_ckpt_v1["global_step"]
+    assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {}
+
+
 @pytest.mark.parametrize(
     ("train_datasets", "val_datasets"),
     [([RandomDataset], [RandomDataset]), ([RandomDataset], [RandomDataset, RandomDataset])],