@@ -604,16 +604,24 @@ def test_fit_loop_reset(tmp_path):
 
     # we load exactly what was saved - no reset yet
     fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
+
+    assert fit_loop.restarting
+    assert fit_loop.epoch_progress.total.ready == 1
+    assert fit_loop.epoch_progress.total.completed == 0
+    assert fit_loop.epoch_progress.current.ready == 1
+    assert fit_loop.epoch_progress.current.completed == 0
+
     # resetting from an end-of-epoch checkpoint SHOULD reset the current counters to 0
     fit_loop.reset()
     epoch_loop.reset()
 
     # resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
+    # since we are restarting at the end of the epoch, `completed` should be updated after the reset
     assert fit_loop.restarting
     assert fit_loop.epoch_progress.total.ready == 1
-    assert fit_loop.epoch_progress.total.completed == 0
+    assert fit_loop.epoch_progress.total.completed == 1
     assert fit_loop.epoch_progress.current.ready == 1
-    assert fit_loop.epoch_progress.current.completed == 0
+    assert fit_loop.epoch_progress.current.completed == 1
 
     # however it should increment completed batch progress, since it was saved immediately prior
     assert epoch_loop.restarting
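Aside: the counters asserted on above come from Lightning's loop-progress tracking, where each unit of work keeps a `ready` count (bumped when the unit starts) and a `completed` count (bumped only once it fully finishes). An end-of-epoch checkpoint is written immediately before the epoch's `completed` counter is incremented, which is why the load-time assertions see `ready == 1, completed == 0` while the post-reset assertions expect both at 1. A minimal sketch of that behavior, using hypothetical names (`Tracker`, `reset_on_restart`) rather than Lightning's actual classes:

from dataclasses import dataclass


@dataclass
class Tracker:
    ready: int = 0      # incremented when a unit of work starts
    completed: int = 0  # incremented only after it fully finishes

    def reset_on_restart(self) -> None:
        # an end-of-epoch checkpoint is saved before `completed` is bumped,
        # so on restart the finished epoch must be counted as done
        self.completed = self.ready


progress = Tracker()
progress.ready += 1          # epoch 0 starts; checkpoint saved at epoch end
progress.reset_on_restart()  # load + reset, as the test does above
assert progress.ready == 1 and progress.completed == 1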
@@ -704,6 +712,7 @@ def test_restart_parity(tmp_path):
         callbacks=[checkpoint_callback],
         logger=False,
         enable_model_summary=False,
+        enable_progress_bar=False,
     )
     trainer.fit(model)
     loss = model.last_loss
@@ -715,6 +724,7 @@ def test_restart_parity(tmp_path):
         callbacks=[checkpoint_callback],
         logger=False,
         enable_model_summary=False,
+        enable_progress_bar=False,
     )
     trainer.fit(model, ckpt_path=str(tmp_path / "epoch=0-step=2.ckpt"))
     loss_v1 = model.last_loss
@@ -749,7 +759,7 @@ def test_restart_parity(tmp_path):
     assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {}
 
 
-def test_restart_parity_with_val(tmp_path):
+def test_restart_with_val_parity(tmp_path):
     model = PredictableBoringModel()
     checkpoint_callback = ModelCheckpoint(
         dirpath=tmp_path,
@@ -814,6 +824,108 @@ def test_restart_parity_with_val(tmp_path):
     assert compare_state_dicts(end_of_epoch_ckpt["state_dict"], end_of_epoch_ckpt_v1["state_dict"]) == {}
 
 
+def test_restart_from_last_parity(tmp_path):
+    model = PredictableBoringModel()
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_last=True,
+        save_top_k=-1,
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model)
+
+    last_ckpt_1 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt"))
+
+    last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True)
+
+    assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {}
+
+
+def test_restart_from_last_with_val_parity(tmp_path):
+    model = PredictableBoringModel()
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_last=True,
+        save_top_k=-1,
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        limit_val_batches=2,
+        val_check_interval=2,
+    )
+    trainer.fit(model)
+
+    last_ckpt_1 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        limit_val_batches=2,
+        val_check_interval=2,
+    )
+    trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        limit_train_batches=2,
+        max_epochs=4,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+        limit_val_batches=2,
+        val_check_interval=2,
+    )
+    trainer.fit(model, ckpt_path=str(tmp_path / "last.ckpt"))
+
+    last_ckpt_2 = torch.load(str(tmp_path / "last.ckpt"), weights_only=True)
+
+    assert compare_state_dicts(last_ckpt_1["loops"], last_ckpt_2["loops"]) == {}
+
+
 @pytest.mark.parametrize(
     ("train_datasets", "val_datasets"),
     [([RandomDataset], [RandomDataset]), ([RandomDataset], [RandomDataset, RandomDataset])],
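Note: every parity check above relies on a `compare_state_dicts` helper defined elsewhere in this test file, with the convention that an empty dict means the two (possibly nested) states match. The file's actual implementation is not shown in this diff; a hypothetical sketch of a recursive diff honoring that contract:

import torch


def compare_state_dicts(a: dict, b: dict, prefix: str = "") -> dict:
    """Return {key path: (left, right)} for every mismatch; {} when equal."""
    diff = {}
    for key in a.keys() | b.keys():
        path = f"{prefix}{key}"
        if key not in a or key not in b:
            # key present on only one side
            diff[path] = (a.get(key, "<missing>"), b.get(key, "<missing>"))
        elif isinstance(a[key], dict) and isinstance(b[key], dict):
            # recurse into nested dicts (e.g. the "loops" sub-state)
            diff.update(compare_state_dicts(a[key], b[key], prefix=f"{path}."))
        elif type(a[key]) is not type(b[key]):
            diff[path] = (a[key], b[key])
        elif isinstance(a[key], torch.Tensor):
            if not torch.equal(a[key], b[key]):
                diff[path] = (a[key], b[key])
        elif a[key] != b[key]:
            diff[path] = (a[key], b[key])
    return diff

With this contract, `assert compare_state_dicts(x, y) == {}` fails with the offending key paths surfaced directly in the assertion message, which is why the tests compare against an empty dict rather than asserting equality of the raw checkpoints.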