Skip to content

Commit 4c07c96

Browse files
grajat90rohitgr7
authored andcommitted
Fixed handling on enums in hyperparams save method, Added unit test, Comments implemented
1 parent c3614f1 commit 4c07c96

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
224224
- Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))
225225

226226

227+
- Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
228+
227229

228230
### Changed
229231

@@ -238,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
238240
Old [neptune-client](https://github.com/neptune-ai/neptune-client) API is supported by `NeptuneClient` from [neptune-contrib](https://github.com/neptune-ai/neptune-contrib) repo.
239241

240242

243+
- Parsing of `enums` type hyperparameters to be saved in the `haprams.yaml` file by tensorboard and csv loggers has been fixed and made in line with how omegaconf parses it. ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
244+
245+
241246
- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
242247

243248

pytorch_lightning/core/saving.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
from argparse import Namespace
2121
from copy import deepcopy
22+
from enum import Enum
2223
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
2324
from warnings import warn
2425

@@ -381,6 +382,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
381382
# drop paramaters which contain some strange datatypes as fsspec
382383
for k, v in hparams.items():
383384
try:
385+
if isinstance(v, Enum):
386+
v = deepcopy(v.name)
384387
yaml.dump(v)
385388
except TypeError:
386389
warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")

tests/models/test_hparams.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
import pickle
1818
from argparse import Namespace
1919
from dataclasses import dataclass
20+
from enum import Enum
2021
from unittest import mock
2122

2223
import cloudpickle
2324
import pytest
2425
import torch
2526
from fsspec.implementations.local import LocalFileSystem
2627
from omegaconf import Container, OmegaConf
28+
from omegaconf.dictconfig import DictConfig
29+
2730
from torch.utils.data import DataLoader
2831

2932
from pytorch_lightning import LightningModule, Trainer
@@ -477,22 +480,36 @@ def test_hparams_pickle_warning(tmpdir):
477480

478481

479482
def test_hparams_save_yaml(tmpdir):
483+
class Options(str, Enum):
484+
option1name = "option1val"
485+
option2name = "option2val"
486+
option3name = "option3val"
480487
hparams = dict(
481-
batch_size=32, learning_rate=0.001, data_root="./any/path/here", nasted=dict(any_num=123, anystr="abcd")
488+
batch_size=32, learning_rate=0.001, data_root="./any/path/here", nested=dict(any_num=123, anystr="abcd"),
489+
switch= Options.option3name
482490
)
483491
path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
484492

493+
def compare(loadedParams, defaultParams: dict):
494+
assert isinstance(loadedParams, (dict, DictConfig))
495+
assert loadedParams.keys() == defaultParams.keys()
496+
for k,v in defaultParams.items():
497+
if isinstance(v, Enum):
498+
assert v.name == loadedParams[k]
499+
else:
500+
assert v == loadedParams[k]
501+
485502
save_hparams_to_yaml(path_yaml, hparams)
486-
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
503+
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
487504

488505
save_hparams_to_yaml(path_yaml, Namespace(**hparams))
489-
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
506+
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
490507

491508
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
492-
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
509+
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False),hparams)
493510

494511
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
495-
assert load_hparams_from_yaml(path_yaml) == hparams
512+
compare(load_hparams_from_yaml(path_yaml), hparams)
496513

497514

498515
class NoArgsSubClassBoringModel(CustomBoringModel):

0 commit comments

Comments
 (0)