@@ -871,18 +871,27 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
     hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
     assert hparams_path.is_file()
     hparams = yaml.safe_load(hparams_path.read_text())
-    expected = {
-        "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "optimizer": "torch.optim.Adam",
-        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-    }
-    assert hparams == expected
+
+    expected_keys = ["_instantiator", "activation", "optimizer", "scheduler"]
+    expected_instantiator = "lightning.pytorch.cli.instantiate_module"
+    expected_activation = "torch.nn.LeakyReLU"
+    expected_optimizer = "torch.optim.Adam"
+    expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"
+
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
 
     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
     assert checkpoint_path.is_file()
-    ckpt = torch.load(checkpoint_path, weights_only=True)
-    assert ckpt["hyper_parameters"] == expected
+    hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
 
     model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
     assert isinstance(model, TestModelSaveHparams)
@@ -898,18 +907,23 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
     cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
     cli.trainer.fit(cli.model)
 
-    expected = {
-        "_instantiator": "lightning.pytorch.cli.instantiate_module",
-        "_class_path": f"{__name__}.TestModelSaveHparams",
-        "optimizer": "torch.optim.Adam",
-        "scheduler": "torch.optim.lr_scheduler.ConstantLR",
-        "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
-    }
+    expected_keys = ["_class_path", "_instantiator", "activation", "optimizer", "scheduler"]
+    expected_instantiator = "lightning.pytorch.cli.instantiate_module"
+    expected_class_path = f"{__name__}.TestModelSaveHparams"
+    expected_activation = "torch.nn.LeakyReLU"
+    expected_optimizer = "torch.optim.Adam"
+    expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"
 
     checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
     assert checkpoint_path.is_file()
-    ckpt = torch.load(checkpoint_path, weights_only=True)
-    assert ckpt["hyper_parameters"] == expected
+    hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
+
+    assert sorted(hparams.keys()) == expected_keys
+    assert hparams["_instantiator"] == expected_instantiator
+    assert hparams["_class_path"] == expected_class_path
+    assert hparams["activation"]["class_path"] == expected_activation
+    assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
+    assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
 
     model = LightningModule.load_from_checkpoint(checkpoint_path)
     assert isinstance(model, TestModelSaveHparams)
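
Note on the dual-form assertions: the `== expected_optimizer or ["class_path"] == expected_optimizer` pattern accepts either serialization the saved hyperparameters may use — a bare class-path string or a dict carrying `class_path` and `init_args`. A minimal sketch of that normalization, assuming a hypothetical helper `class_path_of` that is not part of this diff:

# Minimal sketch, not part of the diff: normalize the two shapes the
# assertions above accept (`class_path_of` is a hypothetical helper).
def class_path_of(value):
    # A bare string is already the class path; otherwise read it from the dict.
    return value if isinstance(value, str) else value["class_path"]

# Both serialized forms resolve to the same class path:
assert class_path_of("torch.optim.Adam") == "torch.optim.Adam"
assert class_path_of(
    {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}}
) == "torch.nn.LeakyReLU"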