@@ -6,11 +6,13 @@
 import pytest
 import torch
 from omegaconf import OmegaConf, Container
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
 
 from pytorch_lightning import Trainer, LightningModule
 from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
 from pytorch_lightning.utilities import AttributeDict
-from tests.base import EvalModelTemplate
+from tests.base import EvalModelTemplate, TrialMNIST
 
 
 class SaveHparamsModel(EvalModelTemplate):
@@ -103,16 +105,16 @@ def test_explicit_args_hparams(tmpdir):
     """
 
     # define model
-    class TestModel(EvalModelTemplate):
+    class LocalModel(EvalModelTemplate):
         def __init__(self, test_arg, test_arg2):
             super().__init__()
             self.save_hyperparameters('test_arg', 'test_arg2')
 
-    model = TestModel(test_arg=14, test_arg2=90)
+    model = LocalModel(test_arg=14, test_arg2=90)
 
     # run standard test suite
-    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, TestModel)
-    model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
+    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
+    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
 
     # config specific tests
     assert model.hparams.test_arg2 == 120
@@ -124,16 +126,16 @@ def test_implicit_args_hparams(tmpdir):
     """
 
     # define model
-    class TestModel(EvalModelTemplate):
+    class LocalModel(EvalModelTemplate):
         def __init__(self, test_arg, test_arg2):
             super().__init__()
             self.save_hyperparameters()
 
-    model = TestModel(test_arg=14, test_arg2=90)
+    model = LocalModel(test_arg=14, test_arg2=90)
 
     # run standard test suite
-    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, TestModel)
-    model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
+    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
+    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
 
     # config specific tests
     assert model.hparams.test_arg2 == 120
@@ -145,12 +147,12 @@ def test_explicit_missing_args_hparams(tmpdir):
     """
 
     # define model
-    class TestModel(EvalModelTemplate):
+    class LocalModel(EvalModelTemplate):
         def __init__(self, test_arg, test_arg2):
             super().__init__()
             self.save_hyperparameters('test_arg')
 
-    model = TestModel(test_arg=14, test_arg2=90)
+    model = LocalModel(test_arg=14, test_arg2=90)
 
     # test proper property assignments
     assert model.hparams.test_arg == 14
@@ -166,7 +168,7 @@ def __init__(self, test_arg, test_arg2):
     assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14
 
     # verify that model loads correctly
-    model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
+    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
     assert model.hparams.test_arg == 14
     assert 'test_arg2' not in model.hparams  # test_arg2 is not registered in class init
 
@@ -427,3 +429,75 @@ def test_hparams_save_yaml(tmpdir):
 
     save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
     assert load_hparams_from_yaml(path_yaml) == hparams
+
+
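+# template subclass whose __init__ takes no arguments (hyperparameter-free model)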
+class NoArgsSubClassEvalModel(EvalModelTemplate):
+    def __init__(self):
+        super().__init__()
+
+
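+# minimal LightningModule with an argument-free __init__; classifies flattened 28x28 images with one linear layer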
+class SimpleNoArgsModel(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.l1 = torch.nn.Linear(28 * 28, 10)
+
+    def forward(self, x):
+        return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch
+        loss = F.cross_entropy(self(x), y)
+        return {'loss': loss, 'log': {'train_loss': loss}}
+
+    def test_step(self, batch, batch_nb):
+        x, y = batch
+        loss = F.cross_entropy(self(x), y)
+        return {'loss': loss, 'log': {'train_loss': loss}}
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.02)
+
+
+@pytest.mark.parametrize("cls", [
+    SimpleNoArgsModel,
+    NoArgsSubClassEvalModel,
+])
+def test_model_nohparams_train_test(tmpdir, cls):
+    """Test models that do not take any arguments in init."""
+
+    model = cls()
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+    )
+
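+    # TrialMNIST (imported above from tests.base) is the trimmed-down MNIST dataset used by the test suite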
+    train_loader = DataLoader(TrialMNIST(os.getcwd(), train=True, download=True), batch_size=32)
+    trainer.fit(model, train_loader)
+
+    test_loader = DataLoader(TrialMNIST(os.getcwd(), train=False, download=True), batch_size=32)
+    trainer.test(test_dataloaders=test_loader)
+
+
+def test_model_ignores_non_exist_kwargument(tmpdir):
+    """Test that the model takes only valid class arguments."""
+
+    class LocalModel(EvalModelTemplate):
+        def __init__(self, batch_size=15):
+            super().__init__(batch_size=batch_size)
+            self.save_hyperparameters()
+
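+    # save_hyperparameters() also records arguments left at their default value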
+    model = LocalModel()
+    assert model.hparams.batch_size == 15
+
+    # verify that the checkpoint saved the correct values
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    trainer.fit(model)
+
+    # verify that a kwarg that is not a valid init argument is silently ignored
+    raw_checkpoint_path = _raw_checkpoint_path(trainer)
+    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, non_exist_kwarg=99)
+    assert 'non_exist_kwarg' not in model.hparams