Skip to content

Commit 861a73b

Browse files
authored
fix loading past checpoints (#2405)
* fix #2334 * chlog
1 parent 66ffbad commit 861a73b

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4242

4343
- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))
4444

45+
- Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))
46+
4547
## [0.8.1] - 2020-06-19
4648

4749
### Fixed

pytorch_lightning/core/saving.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
180180
if hparam_key in checkpoint:
181181
model_args.update(checkpoint[hparam_key])
182182

183-
if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
184-
model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
183+
model_args = _convert_loaded_hparams(model_args, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
185184

186185
args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
187186
cls_spec = inspect.getfullargspec(cls.__init__)
@@ -248,6 +247,18 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
248247
"""
249248

250249

250+
def _convert_loaded_hparams(model_args: dict, hparams_type: Union[Callable, str] = None) -> object:
251+
"""Convert hparams according given type in callable or string (past) format"""
252+
# if not hparams type define
253+
if not hparams_type:
254+
return model_args
255+
# if past checkpoint loaded, convert str to callable
256+
if isinstance(hparams_type, str):
257+
hparams_type = AttributeDict
258+
# convert hparams
259+
return hparams_type(model_args)
260+
261+
251262
def update_hparams(hparams: dict, updates: dict) -> None:
252263
"""
253264
Overrides hparams with new values

tests/models/test_hparams.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def test_collect_init_arguments(tmpdir, cls):
268268
# verify that the checkpoint saved the correct values
269269
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
270270
trainer.fit(model)
271+
271272
raw_checkpoint_path = _raw_checkpoint_path(trainer)
272273

273274
raw_checkpoint = torch.load(raw_checkpoint_path)
@@ -391,6 +392,7 @@ def test_load_past_checkpoint(tmpdir, past_key):
391392
raw_checkpoint_path = _raw_checkpoint_path(trainer)
392393
raw_checkpoint = torch.load(raw_checkpoint_path)
393394
raw_checkpoint[past_key] = raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
395+
raw_checkpoint['hparams_type'] = 'Namespace'
394396
raw_checkpoint[past_key]['batch_size'] = -17
395397
del raw_checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
396398
# save back the checkpoint

0 commit comments

Comments
 (0)