Skip to content

Commit 33e57ce

Browse files
pre-commit-ci[bot]rohitgr7
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4c07c96 commit 33e57ce

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

tests/models/test_hparams.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from fsspec.implementations.local import LocalFileSystem
2727
from omegaconf import Container, OmegaConf
2828
from omegaconf.dictconfig import DictConfig
29-
3029
from torch.utils.data import DataLoader
3130

3231
from pytorch_lightning import LightningModule, Trainer
@@ -484,16 +483,20 @@ class Options(str, Enum):
484483
option1name = "option1val"
485484
option2name = "option2val"
486485
option3name = "option3val"
486+
487487
hparams = dict(
488-
batch_size=32, learning_rate=0.001, data_root="./any/path/here", nested=dict(any_num=123, anystr="abcd"),
489-
switch= Options.option3name
488+
batch_size=32,
489+
learning_rate=0.001,
490+
data_root="./any/path/here",
491+
nested=dict(any_num=123, anystr="abcd"),
492+
switch=Options.option3name,
490493
)
491494
path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
492495

493496
def compare(loadedParams, defaultParams: dict):
494497
assert isinstance(loadedParams, (dict, DictConfig))
495498
assert loadedParams.keys() == defaultParams.keys()
496-
for k,v in defaultParams.items():
499+
for k, v in defaultParams.items():
497500
if isinstance(v, Enum):
498501
assert v.name == loadedParams[k]
499502
else:
@@ -506,7 +509,7 @@ def compare(loadedParams, defaultParams: dict):
506509
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
507510

508511
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
509-
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False),hparams)
512+
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
510513

511514
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
512515
compare(load_hparams_from_yaml(path_yaml), hparams)

0 commit comments

Comments
 (0)