Skip to content

Commit bfe3e8a

Browse files
mauvilsalantiga
andauthored
Change LightningCLI tests to account for future fix in jsonargparse (#20372)
Co-authored-by: Luca Antiga <[email protected]>
1 parent bd5866b commit bfe3e8a

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

tests/tests_pytorch/test_cli.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -871,18 +871,27 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection(cleandir):
871871
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
872872
assert hparams_path.is_file()
873873
hparams = yaml.safe_load(hparams_path.read_text())
874-
expected = {
875-
"_instantiator": "lightning.pytorch.cli.instantiate_module",
876-
"optimizer": "torch.optim.Adam",
877-
"scheduler": "torch.optim.lr_scheduler.ConstantLR",
878-
"activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
879-
}
880-
assert hparams == expected
874+
875+
expected_keys = ["_instantiator", "activation", "optimizer", "scheduler"]
876+
expected_instantiator = "lightning.pytorch.cli.instantiate_module"
877+
expected_activation = "torch.nn.LeakyReLU"
878+
expected_optimizer = "torch.optim.Adam"
879+
expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"
880+
881+
assert sorted(hparams.keys()) == expected_keys
882+
assert hparams["_instantiator"] == expected_instantiator
883+
assert hparams["activation"]["class_path"] == expected_activation
884+
assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
885+
assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
881886

882887
checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
883888
assert checkpoint_path.is_file()
884-
ckpt = torch.load(checkpoint_path, weights_only=True)
885-
assert ckpt["hyper_parameters"] == expected
889+
hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
890+
assert sorted(hparams.keys()) == expected_keys
891+
assert hparams["_instantiator"] == expected_instantiator
892+
assert hparams["activation"]["class_path"] == expected_activation
893+
assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
894+
assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
886895

887896
model = TestModelSaveHparams.load_from_checkpoint(checkpoint_path)
888897
assert isinstance(model, TestModelSaveHparams)
@@ -898,18 +907,23 @@ def test_lightning_cli_load_from_checkpoint_dependency_injection_subclass_mode(c
898907
cli = LightningCLI(TestModelSaveHparams, run=False, auto_configure_optimizers=False, subclass_mode_model=True)
899908
cli.trainer.fit(cli.model)
900909

901-
expected = {
902-
"_instantiator": "lightning.pytorch.cli.instantiate_module",
903-
"_class_path": f"{__name__}.TestModelSaveHparams",
904-
"optimizer": "torch.optim.Adam",
905-
"scheduler": "torch.optim.lr_scheduler.ConstantLR",
906-
"activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.05, "inplace": False}},
907-
}
910+
expected_keys = ["_class_path", "_instantiator", "activation", "optimizer", "scheduler"]
911+
expected_instantiator = "lightning.pytorch.cli.instantiate_module"
912+
expected_class_path = f"{__name__}.TestModelSaveHparams"
913+
expected_activation = "torch.nn.LeakyReLU"
914+
expected_optimizer = "torch.optim.Adam"
915+
expected_scheduler = "torch.optim.lr_scheduler.ConstantLR"
908916

909917
checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"), None)
910918
assert checkpoint_path.is_file()
911-
ckpt = torch.load(checkpoint_path, weights_only=True)
912-
assert ckpt["hyper_parameters"] == expected
919+
hparams = torch.load(checkpoint_path, weights_only=True)["hyper_parameters"]
920+
921+
assert sorted(hparams.keys()) == expected_keys
922+
assert hparams["_instantiator"] == expected_instantiator
923+
assert hparams["_class_path"] == expected_class_path
924+
assert hparams["activation"]["class_path"] == expected_activation
925+
assert hparams["optimizer"] == expected_optimizer or hparams["optimizer"]["class_path"] == expected_optimizer
926+
assert hparams["scheduler"] == expected_scheduler or hparams["scheduler"]["class_path"] == expected_scheduler
913927

914928
model = LightningModule.load_from_checkpoint(checkpoint_path)
915929
assert isinstance(model, TestModelSaveHparams)

0 commit comments

Comments
 (0)