
Commit 1e16681

fix loading with hparams (#2403)

* fix #2386
* extra test
* extra case
* extra test
* chlog
* fix test
1 parent 058c500 commit 1e16681
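
In short, this change lets `load_from_checkpoint` cope with models whose `__init__` takes no arguments and with extra keyword arguments that `__init__` cannot accept, as exercised by the new tests below. A minimal sketch of the user-facing case being fixed; the model name and checkpoint path are illustrative and not part of the commit:

import torch
from pytorch_lightning import LightningModule


class PlainModel(LightningModule):  # illustrative model whose __init__ takes no arguments
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)


# Previously the hparams restored from the checkpoint (and unknown kwargs passed
# here) could be forwarded to PlainModel.__init__ and raise a TypeError; with
# this fix they are filtered against the __init__ signature first.
model = PlainModel.load_from_checkpoint("example.ckpt")  # path is illustrative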

File tree

3 files changed: +94 −18 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))
 
+- Fixed loading model without arguments ([#2403](https://github.com/PyTorchLightning/pytorch-lightning/pull/2403))
+
 ## [0.8.1] - 2020-06-19
 
 ### Fixed

pytorch_lightning/core/saving.py

Lines changed: 10 additions & 6 deletions
@@ -171,6 +171,8 @@ def load_from_checkpoint(
 
     @classmethod
     def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
+        cls_spec = inspect.getfullargspec(cls.__init__)
+        cls_init_args_name = inspect.signature(cls).parameters.keys()
         # pass in the values we saved automatically
         if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
             model_args = {}
@@ -183,23 +185,25 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
             model_args = _convert_loaded_hparams(model_args, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
 
             args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
-            cls_spec = inspect.getfullargspec(cls.__init__)
-            kwargs_identifier = cls_spec.varkw
-            cls_init_args_name = inspect.signature(cls).parameters.keys()
 
             if args_name == 'kwargs':
                 # in case the class cannot take any extra argument filter only the possible
-                if not kwargs_identifier:
-                    model_args = {k: v for k, v in model_args.items() if k in cls_init_args_name}
                 cls_kwargs.update(**model_args)
             elif args_name:
                 if args_name in cls_init_args_name:
                     cls_kwargs.update({args_name: model_args})
                 else:
                     cls_args = (model_args,) + cls_args
 
-        # load the state_dict on the model automatically
+        if not cls_spec.varkw:
+            # filter kwargs according to class init unless it allows any argument via kwargs
+            cls_kwargs = {k: v for k, v in cls_kwargs.items() if k in cls_init_args_name}
+
+        # prevent passing positional arguments if class does not accept any
+        if len(cls_spec.args) <= 1 and not cls_spec.kwonlyargs:
+            cls_args, cls_kwargs = [], {}
         model = cls(*cls_args, **cls_kwargs)
+        # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
 
         # give model a chance to load something
tests/models/test_hparams.py

Lines changed: 82 additions & 12 deletions
@@ -6,11 +6,13 @@
 import pytest
 import torch
 from omegaconf import OmegaConf, Container
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
 
 from pytorch_lightning import Trainer, LightningModule
 from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
 from pytorch_lightning.utilities import AttributeDict
-from tests.base import EvalModelTemplate
+from tests.base import EvalModelTemplate, TrialMNIST
 
 
 class SaveHparamsModel(EvalModelTemplate):
@@ -103,16 +105,16 @@ def test_explicit_args_hparams(tmpdir):
     """
 
     # define model
-    class TestModel(EvalModelTemplate):
+    class LocalModel(EvalModelTemplate):
         def __init__(self, test_arg, test_arg2):
             super().__init__()
             self.save_hyperparameters('test_arg', 'test_arg2')
 
-    model = TestModel(test_arg=14, test_arg2=90)
+    model = LocalModel(test_arg=14, test_arg2=90)
 
     # run standard test suite
-    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, TestModel)
-    model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
+    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
+    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
 
     # config specific tests
     assert model.hparams.test_arg2 == 120
@@ -124,16 +126,16 @@ def test_implicit_args_hparams(tmpdir):
     """
 
     # define model
-    class TestModel(EvalModelTemplate):
+    class LocalModel(EvalModelTemplate):
         def __init__(self, test_arg, test_arg2):
             super().__init__()
             self.save_hyperparameters()
 
-    model = TestModel(test_arg=14, test_arg2=90)
+    model = LocalModel(test_arg=14, test_arg2=90)
 
     # run standard test suite
-    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, TestModel)
-    model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
+    raw_checkpoint_path = _run_standard_hparams_test(tmpdir, model, LocalModel)
+    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=120)
 
     # config specific tests
     assert model.hparams.test_arg2 == 120
@@ -145,12 +147,12 @@ def test_explicit_missing_args_hparams(tmpdir):
     """
 
     # define model
-    class TestModel(EvalModelTemplate):
+    class LocalModel(EvalModelTemplate):
         def __init__(self, test_arg, test_arg2):
             super().__init__()
             self.save_hyperparameters('test_arg')
 
-    model = TestModel(test_arg=14, test_arg2=90)
+    model = LocalModel(test_arg=14, test_arg2=90)
 
     # test proper property assignments
     assert model.hparams.test_arg == 14
@@ -166,7 +168,7 @@ def __init__(self, test_arg, test_arg2):
     assert raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]['test_arg'] == 14
 
     # verify that model loads correctly
-    model = TestModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
+    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, test_arg2=123)
     assert model.hparams.test_arg == 14
     assert 'test_arg2' not in model.hparams  # test_arg2 is not registered in class init
 
@@ -427,3 +429,71 @@ def test_hparams_save_yaml(tmpdir):
 
     save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
     assert load_hparams_from_yaml(path_yaml) == hparams
+
+
+class NoArgsSubClassEvalModel(EvalModelTemplate):
+    def __init__(self):
+        super().__init__()
+
+
+class SimpleNoArgsModel(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.l1 = torch.nn.Linear(28 * 28, 10)
+
+    def forward(self, x):
+        return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch
+        loss = F.cross_entropy(self(x), y)
+        return {'loss': loss, 'log': {'train_loss': loss}}
+
+    def test_step(self, batch, batch_nb):
+        x, y = batch
+        loss = F.cross_entropy(self(x), y)
+        return {'loss': loss, 'log': {'train_loss': loss}}
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.02)
+
+
+@pytest.mark.parametrize("cls", [
+    SimpleNoArgsModel,
+    NoArgsSubClassEvalModel,
+])
+def test_model_nohparams_train_test(tmpdir, cls):
+    """Test models that do not take any argument in init."""
+
+    model = cls()
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+    )
+
+    train_loader = DataLoader(TrialMNIST(os.getcwd(), train=True, download=True), batch_size=32)
+    trainer.fit(model, train_loader)
+
+    test_loader = DataLoader(TrialMNIST(os.getcwd(), train=False, download=True), batch_size=32)
+    trainer.test(test_dataloaders=test_loader)
+
+
+def test_model_ignores_non_exist_kwargument(tmpdir):
+    """Test that the model takes only valid class arguments."""
+
+    class LocalModel(EvalModelTemplate):
+        def __init__(self, batch_size=15):
+            super().__init__(batch_size=batch_size)
+            self.save_hyperparameters()
+
+    model = LocalModel()
+    assert model.hparams.batch_size == 15
+
+    # verify that the checkpoint saved the correct values
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    trainer.fit(model)
+
+    # verify that we can overwrite whatever we want
+    raw_checkpoint_path = _raw_checkpoint_path(trainer)
+    model = LocalModel.load_from_checkpoint(raw_checkpoint_path, non_exist_kwarg=99)
+    assert 'non_exist_kwarg' not in model.hparams
