Skip to content

Commit 4204ef7

Browse files
Borda and justusschock authored
Bugfix/4156 filter hparams for yaml - fsspec (#4158)
* add test * fix * sleepy boy * chlog * Apply suggestions from code review Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Justus Schock <[email protected]>
1 parent 72f1976 commit 4204ef7

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2323

2424
- Fixed `hparams` saving - save the state when `save_hyperparameters()` is called [in `__init__`] ([#4163](https://github.com/PyTorchLightning/pytorch-lightning/pull/4163))
2525

26+
- Fixed runtime failure while exporting `hparams` to yaml ([#4158](https://github.com/PyTorchLightning/pytorch-lightning/pull/4158))
2627

2728

2829
## [1.0.1] - 2020-10-14

pytorch_lightning/core/saving.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import os
1919
from argparse import Namespace
2020
from typing import Union, Dict, Any, Optional, Callable, MutableMapping
21+
from warnings import warn
2122

2223
import fsspec
2324
import torch
@@ -372,10 +373,21 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
372373
OmegaConf.save(OmegaConf.create(hparams), fp, resolve=True)
373374
return
374375

375-
# saving the standard way
376376
assert isinstance(hparams, dict)
377+
hparams_allowed = {}
378+
# drop parameters which contain some strange datatypes as fsspec
379+
for k, v in hparams.items():
380+
try:
381+
yaml.dump(v)
382+
except TypeError as err:
383+
warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
384+
hparams[k] = type(v).__name__
385+
else:
386+
hparams_allowed[k] = v
387+
388+
# saving the standard way
377389
with fs.open(config_yaml, "w", newline="") as fp:
378-
yaml.dump(hparams, fp)
390+
yaml.dump(hparams_allowed, fp)
379391

380392

381393
def convert(val: str) -> Union[int, float, bool, str]:

tests/models/test_hparams.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import cloudpickle
1919
import pytest
2020
import torch
21+
from fsspec.implementations.local import LocalFileSystem
2122
from omegaconf import OmegaConf, Container
2223
from torch.nn import functional as F
2324
from torch.utils.data import DataLoader
@@ -579,3 +580,22 @@ def test_init_arg_with_runtime_change(tmpdir):
579580
path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE)
580581
hparams = load_hparams_from_yaml(path_yaml)
581582
assert hparams.get('running_arg') == 123
583+
584+
585+
class UnsafeParamModel(BoringModel):
586+
def __init__(self, my_path, any_param=123):
587+
super().__init__()
588+
self.save_hyperparameters()
589+
590+
591+
def test_model_with_fsspec_as_parameter(tmpdir):
592+
model = UnsafeParamModel(LocalFileSystem(tmpdir))
593+
trainer = Trainer(
594+
default_root_dir=tmpdir,
595+
limit_train_batches=2,
596+
limit_val_batches=2,
597+
limit_test_batches=2,
598+
max_epochs=1,
599+
)
600+
trainer.fit(model)
601+
trainer.test()

0 commit comments

Comments
 (0)