diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 6fd400aab2724..3c7838f11a85a 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -440,6 +440,41 @@ def __init__(self, same_arg="parent_default", other_arg="other"): assert parent.child.hparams == {"same_arg": "cocofruit"} +@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule]) +def test_save_hyperparameters_ignore(base_class): + """Test if `save_hyperparameter` applies the ignore list correctly during initialization.""" + + class PLSubclass(base_class): + def __init__(self, learning_rate=1e-3, optimizer="adam"): + super().__init__() + self.save_hyperparameters(ignore=["learning_rate"]) + + pl_instance = PLSubclass(learning_rate=0.01, optimizer="sgd") + assert pl_instance.hparams == {"optimizer": "sgd"} + + +@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule]) +def test_save_hyperparameters_ignore_under_composition(base_class): + """Test that in a composed system, hyperparameter saving skips ignored fields from nested modules.""" + + class ChildModule(base_class): + def __init__(self, dropout, activation, init_method): + super().__init__() + self.save_hyperparameters(ignore=["dropout", "activation"]) + + class ParentModule(base_class): + def __init__(self, batch_size, optimizer): + super().__init__() + self.child = ChildModule(dropout=0.1, activation="relu", init_method="xavier") + + class PipelineWrapper: # not a Lightning subclass on purpose + def __init__(self, run_id="abc123", seed=42): + self.parent_module = ParentModule(batch_size=64, optimizer="adam") + + pipeline = PipelineWrapper() + assert pipeline.parent_module.child.hparams == {"init_method": "xavier", "batch_size": 64, "optimizer": "adam"} + + class LocalVariableModelSuperLast(BoringModel): """This model has the super().__init__() call at the end."""