Skip to content

Commit 13bbcaa

Browse files
committed
code reviews and rebase
1 parent 4b558d4 commit 13bbcaa

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

pytorch_lightning/core/saving.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,8 +385,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omega
385385
# drop paramaters which contain some strange datatypes as fsspec
386386
for k, v in hparams.items():
387387
try:
388-
if isinstance(v, Enum):
389-
v = deepcopy(v.name)
388+
v = v.name if isinstance(v, Enum) else v
390389
yaml.dump(v)
391390
except TypeError:
392391
warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")

tests/models/test_hparams.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ class Options(str, Enum):
493493
)
494494
path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
495495

496-
def compare(loaded_params, default_params: dict):
496+
def _compare_params(loaded_params, default_params: dict):
497497
assert isinstance(loaded_params, (dict, DictConfig))
498498
assert loaded_params.keys() == default_params.keys()
499499
for k, v in default_params.items():
@@ -503,16 +503,16 @@ def compare(loaded_params, default_params: dict):
503503
assert v == loaded_params[k]
504504

505505
save_hparams_to_yaml(path_yaml, hparams)
506-
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
506+
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
507507

508508
save_hparams_to_yaml(path_yaml, Namespace(**hparams))
509-
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
509+
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
510510

511511
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
512-
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
512+
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
513513

514514
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
515-
compare(load_hparams_from_yaml(path_yaml), hparams)
515+
_compare_params(load_hparams_from_yaml(path_yaml), hparams)
516516

517517

518518
class NoArgsSubClassBoringModel(CustomBoringModel):

0 commit comments

Comments
 (0)