@@ -442,39 +442,38 @@ def __init__(self, same_arg="parent_default", other_arg="other"):
442442
443443@pytest .mark .parametrize ("base_class" , [HyperparametersMixin , LightningModule , LightningDataModule ])
444444def test_save_hyperparameters_ignore (base_class ):
445- """Test if `save_hyperparameter` applies the ignore of the save_hyperparameters function to args ."""
445+ """Test if `save_hyperparameter` applies the ignore list correctly during initialization ."""
446446
447447 class PLSubclass (base_class ):
448- def __init__ (self , arg1 = "arg1" , arg2 = "arg2 " ):
448+ def __init__ (self , learning_rate = 1e-3 , optimizer = "adam " ):
449449 super ().__init__ ()
450- self .save_hyperparameters (ignore = ["arg1" ])
450+ self .save_hyperparameters (ignore = ["learning_rate" ])
451+
452+ pl_instance = PLSubclass (learning_rate = 0.01 , optimizer = "sgd" )
453+ assert pl_instance .hparams == {"optimizer" : "sgd" }
451454
452- pl_instance = PLSubclass (arg1 = "arg1" , arg2 = "arg2" )
453- assert pl_instance .hparams == {"arg2" : "arg2" }
454455
455456
456457@pytest .mark .parametrize ("base_class" , [HyperparametersMixin , LightningModule , LightningDataModule ])
457458def test_save_hyperparameters_ignore_under_composition (base_class ):
458- """Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
459- collected and ignore is respected."""
459+ """Test that in a composed system, hyperparameter saving skips ignored fields from nested modules."""
460460
461- class ChildInComposition (base_class ):
462- def __init__ (self , fav_fruit , fav_animal , fav_language ):
461+ class ChildModule (base_class ):
462+ def __init__ (self , dropout , activation , init_method ):
463463 super ().__init__ ()
464- self .save_hyperparameters (ignore = ["fav_fruit " , "fav_animal " ])
464+ self .save_hyperparameters (ignore = ["dropout " , "activation " ])
465465
466- class ParentInComposition (base_class ):
467- def __init__ (self , fav_fruit , fav_framework ):
466+ class ParentModule (base_class ):
467+ def __init__ (self , batch_size , optimizer ):
468468 super ().__init__ ()
469- self .child = ChildInComposition ( fav_fruit = "dragonfruit" , fav_animal = "dog " , fav_language = "python " )
469+ self .child = ChildModule ( dropout = 0.1 , activation = "relu " , init_method = "xavier " )
470470
471- class NotPLSubclass : # intentionally not subclassing LightningModule/LightningDataModule
472- def __init__ (self , parent_arg = "parent_default" , other_arg = "other" ):
473- self .pl_parent = ParentInComposition (fav_fruit = "cocofruit" , fav_framework = "lightning" )
474-
475- parent = NotPLSubclass ()
476- assert parent .pl_parent .child .hparams == {"fav_framework" : "lightning" , "fav_language" : "python" }
471+ class PipelineWrapper : # not a Lightning subclass on purpose
472+ def __init__ (self , run_id = "abc123" , seed = 42 ):
473+ self .parent_module = ParentModule (batch_size = 64 , optimizer = "adam" )
477474
475+ pipeline = PipelineWrapper ()
476+ assert pipeline .parent_module .child .hparams == {"init_method" : "xavier" , "batch_size" : 64 , "optimizer" : "adam" }
478477
479478class LocalVariableModelSuperLast (BoringModel ):
480479 """This model has the super().__init__() call at the end."""
0 commit comments