|
17 | 17 | import pickle |
18 | 18 | from argparse import Namespace |
19 | 19 | from dataclasses import dataclass |
| 20 | +from enum import Enum |
20 | 21 | from unittest import mock |
21 | 22 |
|
22 | 23 | import cloudpickle |
23 | 24 | import pytest |
24 | 25 | import torch |
25 | 26 | from fsspec.implementations.local import LocalFileSystem |
26 | 27 | from omegaconf import Container, OmegaConf |
| 28 | +from omegaconf.dictconfig import DictConfig |
| 29 | + |
27 | 30 | from torch.utils.data import DataLoader |
28 | 31 |
|
29 | 32 | from pytorch_lightning import LightningModule, Trainer |
@@ -477,22 +480,36 @@ def test_hparams_pickle_warning(tmpdir): |
477 | 480 |
|
478 | 481 |
|
479 | 482 | def test_hparams_save_yaml(tmpdir): |
| 483 | + class Options(str, Enum): |
| 484 | + option1name = "option1val" |
| 485 | + option2name = "option2val" |
| 486 | + option3name = "option3val" |
480 | 487 | hparams = dict( |
481 | | - batch_size=32, learning_rate=0.001, data_root="./any/path/here", nasted=dict(any_num=123, anystr="abcd") |
| 488 | + batch_size=32, learning_rate=0.001, data_root="./any/path/here", nested=dict(any_num=123, anystr="abcd"), |
| 489 | + switch= Options.option3name |
482 | 490 | ) |
483 | 491 | path_yaml = os.path.join(tmpdir, "testing-hparams.yaml") |
484 | 492 |
|
| 493 | + def compare(loadedParams, defaultParams: dict): |
| 494 | + assert isinstance(loadedParams, (dict, DictConfig)) |
| 495 | + assert loadedParams.keys() == defaultParams.keys() |
| 496 | + for k,v in defaultParams.items(): |
| 497 | + if isinstance(v, Enum): |
| 498 | + assert v.name == loadedParams[k] |
| 499 | + else: |
| 500 | + assert v == loadedParams[k] |
| 501 | + |
485 | 502 | save_hparams_to_yaml(path_yaml, hparams) |
486 | | - assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams |
| 503 | + compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams) |
487 | 504 |
|
488 | 505 | save_hparams_to_yaml(path_yaml, Namespace(**hparams)) |
489 | | - assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams |
| 506 | + compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams) |
490 | 507 |
|
491 | 508 | save_hparams_to_yaml(path_yaml, AttributeDict(hparams)) |
492 | | - assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams |
| 509 | + compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False),hparams) |
493 | 510 |
|
494 | 511 | save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams)) |
495 | | - assert load_hparams_from_yaml(path_yaml) == hparams |
| 512 | + compare(load_hparams_from_yaml(path_yaml), hparams) |
496 | 513 |
|
497 | 514 |
|
498 | 515 | class NoArgsSubClassBoringModel(CustomBoringModel): |
|
0 commit comments