@@ -440,6 +440,42 @@ def __init__(self, same_arg="parent_default", other_arg="other"):
440440 assert parent .child .hparams == {"same_arg" : "cocofruit" }
441441
442442
443+ @pytest .mark .parametrize ("base_class" , [HyperparametersMixin , LightningModule , LightningDataModule ])
444+ def test_save_hyperparameters_ignore (base_class ):
445+ """Test if `save_hyperparameter` applies the ignore of the save_hyperparameters function to args."""
446+
447+ class PLSubclass (base_class ):
448+ def __init__ (self , arg1 = "arg1" , arg2 = "arg2" ):
449+ super ().__init__ ()
450+ self .save_hyperparameters (ignore = ["arg1" ])
451+
452+ pl_instance = PLSubclass (arg1 = "arg1" , arg2 = "arg2" )
453+ assert pl_instance .hparams == {"arg2" : "arg2" }
454+
455+
456+ @pytest .mark .parametrize ("base_class" , [HyperparametersMixin , LightningModule , LightningDataModule ])
457+ def test_save_hyperparameters_ignore_under_composition (base_class ):
458+ """Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
459+ collected and ignore is respected."""
460+
461+ class ChildInComposition (base_class ):
462+ def __init__ (self , fav_fruit , fav_animal , fav_language ):
463+ super ().__init__ ()
464+ self .save_hyperparameters (ignore = ["fav_fruit" , "fav_animal" ])
465+
466+ class ParentInComposition (base_class ):
467+ def __init__ (self , fav_fruit , fav_framework ):
468+ super ().__init__ ()
469+ self .child = ChildInComposition (fav_fruit = "dragonfruit" , fav_animal = "dog" , fav_language = "python" )
470+
471+ class NotPLSubclass : # intentionally not subclassing LightningModule/LightningDataModule
472+ def __init__ (self , parent_arg = "parent_default" , other_arg = "other" ):
473+ self .pl_parent = ParentInComposition (fav_fruit = "cocofruit" , fav_framework = "lightning" )
474+
475+ parent = NotPLSubclass ()
476+ assert parent .pl_parent .child .hparams == {"fav_framework" : "lightning" , "fav_language" : "python" }
477+
478+
443479class LocalVariableModelSuperLast (BoringModel ):
444480 """This model has the super().__init__() call at the end."""
445481
0 commit comments