@@ -145,7 +145,7 @@ def test_save_restore_dataloader_state(self) -> None:
             self.assertEqual(stateful_dataloader.load_state_dict_call_count, 1)
             self.assertEqual(
                 log.output[0],
-                "WARNING:torchtnt.utils.rank_zero_log:train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot",
+                "WARNING:torchtnt.framework.callbacks.dcp_saver:dataloader (train) was passed to `restore` but no dataloader exists in checkpoint metadata.",
             )

     def test_restore_from_latest(self) -> None:
@@ -500,7 +500,7 @@ def test_save_restore_multi_optimizers(self) -> None:
             my_unit_clone = DummyMultiOptimUnit(input_dim=input_dim)
             dcp_cb.restore_from_latest(temp_dir, my_unit_clone)

-    def test_save_predict(self) -> None:
+    def test_save_restore_predict(self) -> None:
         input_dim = 2
         dataset_len = 10
         batch_size = 2
@@ -537,19 +537,51 @@ def test_save_predict(self) -> None:
             ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
             self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))

+            expected_keys = [
+                "predict_progress",
+                "predict_dataloader",
+                "output_mean",
+            ]
+
             storage_reader = FsspecReader(ckpt_path)
             metadata = storage_reader.read_metadata()
             self.assertCountEqual(
                 # Get base keys after the app_state wrapper
                 {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
-                [
-                    "predict_progress",
-                    "predict_dataloader",
-                    "output_mean",
-                ],
+                expected_keys,
+            )
+
+            # Now make sure that the same exact keys are used when restoring
+            with patch("torchtnt.framework.callbacks.dcp_saver.dcp.load") as mock_load:
+                DistributedCheckpointSaver.restore(
+                    ckpt_path, my_unit, predict_dataloader=dataloader
+                )
+                self.assertCountEqual(
+                    [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
+                    expected_keys,
+                )
+
+            # Double check that the module parameters are not overwritten when loading ckpt
+            my_unit = DummyPredictUnit(input_dim=input_dim)
+            my_unit.module.weight.data.fill_(0.0)
+            my_unit.module.bias.data.fill_(1.0)
+
+            DistributedCheckpointSaver.restore(
+                ckpt_path, my_unit, predict_dataloader=dataloader
+            )
+
+            self.assertTrue(
+                torch.allclose(
+                    my_unit.module.weight.data, torch.zeros(input_dim, input_dim)
+                )
+            )
+            self.assertTrue(
+                torch.allclose(
+                    my_unit.module.bias.data, torch.ones(input_dim, input_dim)
+                )
             )

-    def test_save_evaluate(self) -> None:
+    def test_save_restore_evaluate(self) -> None:
         input_dim = 2
         dataset_len = 10
         batch_size = 2
@@ -580,18 +612,49 @@ def test_save_evaluate(self) -> None:
             ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
             self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))

+            expected_keys = [
+                "eval_progress",
+                "eval_dataloader",
+            ]
             storage_reader = FsspecReader(ckpt_path)
             metadata = storage_reader.read_metadata()
             self.assertCountEqual(
                 # Get base keys after the app_state wrapper
                 {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
-                [
-                    "eval_progress",
-                    "eval_dataloader",
-                ],
+                expected_keys,
             )

-    def test_save_fit_eval_every_n_epochs(self) -> None:
+            # Now make sure that the same exact keys are used when restoring
+            with patch("torchtnt.framework.callbacks.dcp_saver.dcp.load") as mock_load:
+                DistributedCheckpointSaver.restore(
+                    ckpt_path, my_unit, eval_dataloader=dataloader
+                )
+                self.assertCountEqual(
+                    [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
+                    expected_keys,
+                )
+
+            # Double check that the module parameters are not overwritten when loading ckpt
+            my_unit = DummyEvalUnit(input_dim=input_dim)
+            my_unit.module.weight.data.fill_(0.0)
+            my_unit.module.bias.data.fill_(1.0)
+
+            DistributedCheckpointSaver.restore(
+                ckpt_path, my_unit, predict_dataloader=dataloader
+            )
+
+            self.assertTrue(
+                torch.allclose(
+                    my_unit.module.weight.data, torch.zeros(input_dim, input_dim)
+                )
+            )
+            self.assertTrue(
+                torch.allclose(
+                    my_unit.module.bias.data, torch.ones(input_dim, input_dim)
+                )
+            )
+
+    def test_save_restore_fit_eval_every_n_epochs(self) -> None:
         input_dim = 2
         dataset_len = 10
         batch_size = 2
@@ -625,33 +688,52 @@ def test_save_fit_eval_every_n_epochs(self) -> None:
             )

             generated_ckpts = os.listdir(temp_dir)
-            expected_ckpts = [
-                "epoch_0_train_step_2_eval_step_0",
-                "epoch_0_train_step_4_eval_step_0",
-                "epoch_1_train_step_5_eval_step_2",
-                "epoch_1_train_step_5_eval_step_4",
+            expected_ckpts_to_dl_mapping: Dict[str, str] = {
+                "epoch_0_train_step_2_eval_step_0": "train_dataloader",
+                "epoch_0_train_step_4_eval_step_0": "train_dataloader",
+                "epoch_1_train_step_5_eval_step_2": "eval_dataloader",
+                "epoch_1_train_step_5_eval_step_4": "eval_dataloader",
+            }
+            self.assertCountEqual(
+                generated_ckpts, [*expected_ckpts_to_dl_mapping.keys()]
+            )
+
+            expected_keys = [
+                "module",  # Both train and eval checkpoints save full app_state in fit
+                "optimizer",
+                "lr_scheduler",
+                "train_progress",
+                "eval_progress",
+                "predict_progress",  # included because of AutoUnit
+                "output_mean",
             ]
-            self.assertCountEqual(generated_ckpts, expected_ckpts)

-            expected_dataloader = ["train_dataloader"] * 2 + ["eval_dataloader"] * 2
-            for ckpt_path, dl_key in zip(expected_ckpts, expected_dataloader):
-                storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path))
+            for ckpt_path, dl_key in expected_ckpts_to_dl_mapping.items():
+                full_ckpt_path = os.path.join(temp_dir, ckpt_path)
+                expected_keys_with_dl = expected_keys + [dl_key]
+                storage_reader = FsspecReader(full_ckpt_path)
                 metadata = storage_reader.read_metadata()
                 self.assertCountEqual(
                     # Get base keys after the app_state wrapper
                     {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
-                    [
-                        "module",  # Both train and eval checkpoints save full app_state in fit
-                        "optimizer",
-                        "lr_scheduler",
-                        "train_progress",
-                        "eval_progress",
-                        "predict_progress",  # included because of AutoUnit
-                        dl_key,
-                        "output_mean",
-                    ],
+                    expected_keys_with_dl,
                 )

+                # Now make sure that the same exact keys are used when restoring
+                with patch(
+                    "torchtnt.framework.callbacks.dcp_saver.dcp.load"
+                ) as mock_load:
+                    DistributedCheckpointSaver.restore(
+                        full_ckpt_path,
+                        my_unit,
+                        train_dataloader=train_dataloader,
+                        eval_dataloader=eval_dataloader,
+                    )
+                    self.assertCountEqual(
+                        [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
+                        expected_keys_with_dl,
+                    )
+
     def test_save_fit_eval_every_n_steps(self) -> None:
         input_dim = 2

@@ -710,24 +792,42 @@ def test_save_fit_eval_every_n_steps(self) -> None:
                 generated_ckpts, [*expected_ckpts_to_dl_mapping.keys()]
             )

+            expected_keys = [
+                "module",  # Both train and eval checkpoints save full app_state in fit
+                "optimizer",
+                "lr_scheduler",
+                "train_progress",
+                "eval_progress",
+                "predict_progress",  # included because of AutoUnit
+                "output_mean",
+            ]
+
             for ckpt_path, expected_dls in expected_ckpts_to_dl_mapping.items():
-                storage_reader = FsspecReader(os.path.join(temp_dir, ckpt_path))
+                expected_keys_with_dls = [*expected_keys, *expected_dls]
+                full_ckpt_path = os.path.join(temp_dir, ckpt_path)
+                storage_reader = FsspecReader(full_ckpt_path)
                 metadata = storage_reader.read_metadata()
                 self.assertCountEqual(
                     # Get base keys after the app_state wrapper
                     {key.split(".")[1] for key in metadata.state_dict_metadata.keys()},
-                    [
-                        "module",  # Both train and eval checkpoints save full app_state in fit
-                        "optimizer",
-                        "lr_scheduler",
-                        "train_progress",
-                        "eval_progress",
-                        "predict_progress",  # included because of AutoUnit
-                        "output_mean",
-                        *expected_dls,
-                    ],
+                    expected_keys_with_dls,
                 )

+                # Now make sure that the same exact keys are used when restoring
+                with patch(
+                    "torchtnt.framework.callbacks.dcp_saver.dcp.load"
+                ) as mock_load:
+                    DistributedCheckpointSaver.restore(
+                        full_ckpt_path,
+                        my_unit,
+                        train_dataloader=train_dataloader,
+                        eval_dataloader=eval_dataloader,
+                    )
+                    self.assertCountEqual(
+                        [*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
+                        expected_keys_with_dls,
+                    )
+

 class DummyStatefulDataLoader:
     def __init__(self, dataloader: DataLoader) -> None: