Skip to content

Commit 7cc0749

Browse files
committed
update
1 parent 348d600 commit 7cc0749

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

tests/tests_pytorch/models/test_hparams.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -442,39 +442,38 @@ def __init__(self, same_arg="parent_default", other_arg="other"):
442442

443443
@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule])
444444
def test_save_hyperparameters_ignore(base_class):
445-
"""Test if `save_hyperparameter` applies the ignore of the save_hyperparameters function to args."""
445+
"""Test if `save_hyperparameter` applies the ignore list correctly during initialization."""
446446

447447
class PLSubclass(base_class):
448-
def __init__(self, arg1="arg1", arg2="arg2"):
448+
def __init__(self, learning_rate=1e-3, optimizer="adam"):
449449
super().__init__()
450-
self.save_hyperparameters(ignore=["arg1"])
450+
self.save_hyperparameters(ignore=["learning_rate"])
451+
452+
pl_instance = PLSubclass(learning_rate=0.01, optimizer="sgd")
453+
assert pl_instance.hparams == {"optimizer": "sgd"}
451454

452-
pl_instance = PLSubclass(arg1="arg1", arg2="arg2")
453-
assert pl_instance.hparams == {"arg2": "arg2"}
454455

455456

456457
@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule])
457458
def test_save_hyperparameters_ignore_under_composition(base_class):
458-
"""Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get
459-
collected and ignore is respected."""
459+
"""Test that in a composed system, hyperparameter saving skips ignored fields from nested modules."""
460460

461-
class ChildInComposition(base_class):
462-
def __init__(self, fav_fruit, fav_animal, fav_language):
461+
class ChildModule(base_class):
462+
def __init__(self, dropout, activation, init_method):
463463
super().__init__()
464-
self.save_hyperparameters(ignore=["fav_fruit", "fav_animal"])
464+
self.save_hyperparameters(ignore=["dropout", "activation"])
465465

466-
class ParentInComposition(base_class):
467-
def __init__(self, fav_fruit, fav_framework):
466+
class ParentModule(base_class):
467+
def __init__(self, batch_size, optimizer):
468468
super().__init__()
469-
self.child = ChildInComposition(fav_fruit="dragonfruit", fav_animal="dog", fav_language="python")
469+
self.child = ChildModule(dropout=0.1, activation="relu", init_method="xavier")
470470

471-
class NotPLSubclass: # intentionally not subclassing LightningModule/LightningDataModule
472-
def __init__(self, parent_arg="parent_default", other_arg="other"):
473-
self.pl_parent = ParentInComposition(fav_fruit="cocofruit", fav_framework="lightning")
474-
475-
parent = NotPLSubclass()
476-
assert parent.pl_parent.child.hparams == {"fav_framework": "lightning", "fav_language": "python"}
471+
class PipelineWrapper: # not a Lightning subclass on purpose
472+
def __init__(self, run_id="abc123", seed=42):
473+
self.parent_module = ParentModule(batch_size=64, optimizer="adam")
477474

475+
pipeline = PipelineWrapper()
476+
assert pipeline.parent_module.child.hparams == {"init_method": "xavier", "batch_size": 64, "optimizer": "adam"}
478477

479478
class LocalVariableModelSuperLast(BoringModel):
480479
"""This model has the super().__init__() call at the end."""

0 commit comments

Comments
 (0)