Skip to content

Commit 73cb165

Browse files
committed
resolve comments
1 parent 8749bd3 commit 73cb165

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

pytorch_lightning/core/saving.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
319319
320320
Args:
321321
config_yaml: Path to config yaml file
322-
use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
323-
the hparams will be converted to `DictConfig` if possible
322+
use_omegaconf: If omegaconf is available and `use_omegaconf=True`,
323+
the hparams will be converted to ``DictConfig`` if possible.
324324
325325
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
326326
>>> path_yaml = './testing-hparams.yaml'
@@ -352,8 +352,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omega
352352
Args:
353353
config_yaml: path to new YAML file
354354
hparams: parameters to be saved
355-
use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
356-
the hparams will be converted to `DictConfig` if possible
355+
use_omegaconf: If omegaconf is available and `use_omegaconf=True`,
356+
the hparams will be converted to ``DictConfig`` if possible.
357357
358358
"""
359359
fs = get_filesystem(config_yaml)

tests/models/test_hparams.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -493,14 +493,14 @@ class Options(str, Enum):
493493
)
494494
path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
495495

496-
def compare(loadedParams, defaultParams: dict):
497-
assert isinstance(loadedParams, (dict, DictConfig))
498-
assert loadedParams.keys() == defaultParams.keys()
499-
for k, v in defaultParams.items():
496+
def compare(loaded_params, default_params: dict):
497+
assert isinstance(loaded_params, (dict, DictConfig))
498+
assert loaded_params.keys() == default_params.keys()
499+
for k, v in default_params.items():
500500
if isinstance(v, Enum):
501-
assert v.name == loadedParams[k]
501+
assert v.name == loaded_params[k]
502502
else:
503-
assert v == loadedParams[k]
503+
assert v == loaded_params[k]
504504

505505
save_hparams_to_yaml(path_yaml, hparams)
506506
compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)

0 commit comments

Comments
 (0)