@@ -878,18 +878,30 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
     hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
     assert hparams_path.is_file()
     hparams = yaml.safe_load(hparams_path.read_text())
-    expected = {
-        "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "optimizer": "torch.optim.Adam",
-        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-    }
-    assert hparams == expected
+
+    expected_keys = ["_instantiator", "activation", "optimizer", "scheduler"]
+    expected_instantiator = "lightning.pytorch.cli.instantiate_module"
+    expected_activation = "torch.nn.LeakyReLU"
+    expected_optimizer = "torch.optim.Adam"
+    expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"
+
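+    # The optimizer and scheduler may be saved either as a plain class-path string
+    # or as a dict with "class_path" and "init_args", so the asserts accept both forms.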
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
 
     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
     assert checkpoint_path.is_file()
-    ckpt = torch.load(checkpoint_path, weights_only=True)
-    assert ckpt["hyper_parameters"] == expected
+    hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
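+    # The checkpoint must carry the same hyperparameters that were written to hparams.yaml.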
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
 
     model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
     assert isinstance(model, TestModelSaveHparams)
@@ -905,18 +914,24 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
     cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
     cli.trainer.fit(cli.model)
 
-    expected = {
-        "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "_class_path": f"{__name__}.TestModelSaveHparams",
-        "optimizer": "torch.optim.Adam",
-        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-    }
+    expected_keys = ["_class_path", "_instantiator", "activation", "optimizer", "scheduler"]
+    expected_instantiator = "lightning.pytorch.cli.instantiate_module"
+    expected_class_path = f"{__name__}.TestModelSaveHparams"
+    expected_activation = "torch.nn.LeakyReLU"
+    expected_optimizer = "torch.optim.Adam"
+    expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"
 
     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
     assert checkpoint_path.is_file()
-    ckpt = torch.load(checkpoint_path, weights_only=True)
-    assert ckpt["hyper_parameters"] == expected
+    hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
+
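+    # As above, accept both the plain class-path string and the dict form for optimizer/scheduler.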
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["_class_path"] == expected_class_path
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
 
     model = LightningModule.load_from_checkpoint(checkpoint_path)
     assert isinstance(model, TestModelSaveHparams)