Skip to content

Commit 8aec4da

Browse files
committed
Fixed handling on enums in hyperparams save method, Added unit test, Comments implemented
1 parent 46b00a7 commit 8aec4da

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
102102

103103
### Changed
104104

105+
- Parsing of `enums` type hyperparameters to be saved in the `haprams.yaml` file by tensorboard and csv loggers has been fixed and made in line with how omegaconf parses it. ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
106+
107+
105108
- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
106109

107110

pytorch_lightning/core/saving.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
from argparse import Namespace
2121
from copy import deepcopy
22+
from enum import Enum
2223
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
2324
from warnings import warn
2425

@@ -383,6 +384,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
383384
# drop paramaters which contain some strange datatypes as fsspec
384385
for k, v in hparams.items():
385386
try:
387+
if isinstance(v, Enum):
388+
v = deepcopy(v.name)
386389
yaml.dump(v)
387390
except TypeError:
388391
warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")

tests/models/test_hparams.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
import pickle
1818
from argparse import Namespace
1919
from dataclasses import dataclass
20+
from enum import Enum
2021
from unittest import mock
2122

2223
import cloudpickle
2324
import pytest
2425
import torch
2526
from fsspec.implementations.local import LocalFileSystem
2627
from omegaconf import Container, OmegaConf
28+
from omegaconf.dictconfig import DictConfig
29+
2730
from torch.utils.data import DataLoader
2831

2932
from pytorch_lightning import LightningModule, Trainer
@@ -477,22 +480,36 @@ def test_hparams_pickle_warning(tmpdir):
477480

478481

479482
def test_hparams_save_yaml(tmpdir):
483+
class Options(str, Enum):
484+
option1name = "option1val"
485+
option2name = "option2val"
486+
option3name = "option3val"
480487
hparams = dict(
481-
batch_size=32, learning_rate=0.001, data_root="./any/path/here", nasted=dict(any_num=123, anystr="abcd")
488+
batch_size=32, learning_rate=0.001, data_root="./any/path/here", nested=dict(any_num=123, anystr="abcd"),
489+
switch= Options.option3name
482490
)
483491
path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
484492

493+
def compare(loadedParams, defaultParams: dict):
494+
assert isinstance(loadedParams, (dict, DictConfig))
495+
assert loadedParams.keys() == defaultParams.keys()
496+
for k,v in defaultParams.items():
497+
if isinstance(v, Enum):
498+
assert v.name == loadedParams[k]
499+
else:
500+
assert v == loadedParams[k]
501+
485502
save_hparams_to_yaml(path_yaml, hparams)
486-
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
503+
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
487504

488505
save_hparams_to_yaml(path_yaml, Namespace(**hparams))
489-
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
506+
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
490507

491508
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
492-
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
509+
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False),hparams)
493510

494511
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
495-
assert load_hparams_from_yaml(path_yaml) == hparams
512+
compare(load_hparams_from_yaml(path_yaml), hparams)
496513

497514

498515
class NoArgsSubClassBoringModel(CustomBoringModel):

0 commit comments

Comments
 (0)