Skip to content

Commit 47e7a28

Browse files
grajat90tchatonawaelchlirohitgr7
authored
Fix Enums parsing in generated hparms yaml (#9170)
Co-authored-by: tchaton <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Rohit Gupta <[email protected]>
1 parent 0e0247a commit 47e7a28

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
226226
- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))
227227

228228

229+
- Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
230+
231+
229232
- Added `ckpt_path` argument for `trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
230233

231234

@@ -242,6 +245,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
242245
Old [neptune-client](https://github.com/neptune-ai/neptune-client) API is supported by `NeptuneClient` from [neptune-contrib](https://github.com/neptune-ai/neptune-contrib) repo.
243246

244247

248+
- Parsing of `enums` type hyperparameters to be saved in the `haprams.yaml` file by tensorboard and csv loggers has been fixed and made in line with how omegaconf parses it. ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
249+
250+
245251
- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
246252

247253

pytorch_lightning/core/saving.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
from argparse import Namespace
2121
from copy import deepcopy
22+
from enum import Enum
2223
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
2324
from warnings import warn
2425

@@ -318,8 +319,8 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
318319
319320
Args:
320321
config_yaml: Path to config yaml file
321-
use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
322-
the hparams will be converted to `DictConfig` if possible
322+
use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
323+
the hparams will be converted to ``DictConfig`` if possible.
323324
324325
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
325326
>>> path_yaml = './testing-hparams.yaml'
@@ -346,11 +347,14 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
346347
return hparams
347348

348349

349-
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
350+
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:
350351
"""
351352
Args:
352353
config_yaml: path to new YAML file
353354
hparams: parameters to be saved
355+
use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
356+
the hparams will be converted to ``DictConfig`` if possible.
357+
354358
"""
355359
fs = get_filesystem(config_yaml)
356360
if not fs.isdir(os.path.dirname(config_yaml)):
@@ -363,7 +367,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
363367
hparams = dict(hparams)
364368

365369
# saving with OmegaConf objects
366-
if _OMEGACONF_AVAILABLE:
370+
if _OMEGACONF_AVAILABLE and use_omegaconf:
367371
# deepcopy: hparams from user shouldn't be resolved
368372
hparams = deepcopy(hparams)
369373
hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True)
@@ -381,6 +385,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
381385
# drop paramaters which contain some strange datatypes as fsspec
382386
for k, v in hparams.items():
383387
try:
388+
v = v.name if isinstance(v, Enum) else v
384389
yaml.dump(v)
385390
except TypeError:
386391
warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")

tests/models/test_hparams.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717
import pickle
1818
from argparse import Namespace
1919
from dataclasses import dataclass
20+
from enum import Enum
2021
from unittest import mock
2122

2223
import cloudpickle
2324
import pytest
2425
import torch
2526
from fsspec.implementations.local import LocalFileSystem
2627
from omegaconf import Container, OmegaConf
28+
from omegaconf.dictconfig import DictConfig
2729
from torch.utils.data import DataLoader
2830

2931
from pytorch_lightning import LightningModule, Trainer
@@ -477,22 +479,40 @@ def test_hparams_pickle_warning(tmpdir):
477479

478480

479481
def test_hparams_save_yaml(tmpdir):
482+
class Options(str, Enum):
483+
option1name = "option1val"
484+
option2name = "option2val"
485+
option3name = "option3val"
486+
480487
hparams = dict(
481-
batch_size=32, learning_rate=0.001, data_root="./any/path/here", nasted=dict(any_num=123, anystr="abcd")
488+
batch_size=32,
489+
learning_rate=0.001,
490+
data_root="./any/path/here",
491+
nested=dict(any_num=123, anystr="abcd"),
492+
switch=Options.option3name,
482493
)
483494
path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
484495

496+
def _compare_params(loaded_params, default_params: dict):
497+
assert isinstance(loaded_params, (dict, DictConfig))
498+
assert loaded_params.keys() == default_params.keys()
499+
for k, v in default_params.items():
500+
if isinstance(v, Enum):
501+
assert v.name == loaded_params[k]
502+
else:
503+
assert v == loaded_params[k]
504+
485505
save_hparams_to_yaml(path_yaml, hparams)
486-
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
506+
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
487507

488508
save_hparams_to_yaml(path_yaml, Namespace(**hparams))
489-
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
509+
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
490510

491511
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
492-
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
512+
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
493513

494514
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
495-
assert load_hparams_from_yaml(path_yaml) == hparams
515+
_compare_params(load_hparams_from_yaml(path_yaml), hparams)
496516

497517

498518
class NoArgsSubClassBoringModel(CustomBoringModel):

0 commit comments

Comments
 (0)