Skip to content

Commit 348d600

Browse files
committed
add tests to check for save_hyperparameter: ignore
1 parent 03635d2 commit 348d600

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/tests_pytorch/models/test_hparams.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,42 @@ def __init__(self, same_arg="parent_default", other_arg="other"):
440440
assert parent.child.hparams == {"same_arg": "cocofruit"}
441441

442442

443+
@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule])
444+
def test_save_hyperparameters_ignore(base_class):
445+
"""Test if `save_hyperparameter` applies the ignore of the save_hyperparameters function to args."""
446+
447+
class PLSubclass(base_class):
448+
def __init__(self, arg1="arg1", arg2="arg2"):
449+
super().__init__()
450+
self.save_hyperparameters(ignore=["arg1"])
451+
452+
pl_instance = PLSubclass(arg1="arg1", arg2="arg2")
453+
assert pl_instance.hparams == {"arg2": "arg2"}
454+
455+
456+
@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule])
457+
def test_save_hyperparameters_ignore_under_composition(base_class):
458+
"""Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
459+
collected and ignore is respected."""
460+
461+
class ChildInComposition(base_class):
462+
def __init__(self, fav_fruit, fav_animal, fav_language):
463+
super().__init__()
464+
self.save_hyperparameters(ignore=["fav_fruit", "fav_animal"])
465+
466+
class ParentInComposition(base_class):
467+
def __init__(self, fav_fruit, fav_framework):
468+
super().__init__()
469+
self.child = ChildInComposition(fav_fruit="dragonfruit", fav_animal="dog", fav_language="python")
470+
471+
class NotPLSubclass: # intentionally not subclassing LightningModule/LightningDataModule
472+
def __init__(self, parent_arg="parent_default", other_arg="other"):
473+
self.pl_parent = ParentInComposition(fav_fruit="cocofruit", fav_framework="lightning")
474+
475+
parent = NotPLSubclass()
476+
assert parent.pl_parent.child.hparams == {"fav_framework": "lightning", "fav_language": "python"}
477+
478+
443479
class LocalVariableModelSuperLast(BoringModel):
444480
"""This model has the super().__init__() call at the end."""
445481

0 commit comments

Comments
 (0)