Skip to content

Commit cbdc61c

Browse files
committed
add additional testing.
1 parent 6711052 commit cbdc61c

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/tests_pytorch/models/test_hparams.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,21 @@ def _raw_checkpoint_path(trainer) -> str:
436436
return raw_checkpoint_path
437437

438438

439+
def test_collect_init_arguments_in_other_methods():
440+
class _ABCModelCreator:
441+
def init(self, model, **kwargs) -> LightningModule:
442+
self.model = model
443+
return self.model
444+
445+
class ConcreteModelCreator(_ABCModelCreator):
446+
def init(self, model=None, **kwargs) -> LightningModule:
447+
return super().init(model=model or CustomBoringModel(**kwargs))
448+
449+
model_creator = ConcreteModelCreator()
450+
model = model_creator.init(batch_size=123)
451+
assert model.hparams.batch_size == 123
452+
453+
439454
@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule])
440455
def test_save_hyperparameters_under_composition(base_class):
441456
"""Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get

0 commit comments

Comments
 (0)