From 759bcee3854bd6b3d02ca280f48ab8c974ce916b Mon Sep 17 00:00:00 2001
From: "grajat90@gmail.com" <grajat90@gmail.com>
Date: Fri, 27 Aug 2021 19:29:59 +0530
Subject: [PATCH 1/6] Fixed handling of enums in hparams save method, added
 unit test, addressed review comments

---
 CHANGELOG.md                     |  6 ++++++
 pytorch_lightning/core/saving.py |  3 +++
 tests/models/test_hparams.py     | 27 ++++++++++++++++++++++-----
 3 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 16e754038165c..4a8ac255506b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -221,6 +221,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `XLACheckpointIO` plugin ([#9972](https://github.com/PyTorchLightning/pytorch-lightning/pull/9972))
 
 
+- Added `use_omegaconf` argument to `save_hparams_to_yaml` ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
+
+
 ### Changed
 
 - Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
@@ -234,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   Old [neptune-client](https://github.com/neptune-ai/neptune-client) API is supported by `NeptuneClient` from [neptune-contrib](https://github.com/neptune-ai/neptune-contrib) repo.
 
+- Parsing of `Enum`-typed hyperparameters saved to the `hparams.yaml` file by the TensorBoard and CSV loggers has been fixed and brought in line with how omegaconf parses them ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
+
+
 - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
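For context, a minimal sketch (not part of the patch series) of the failure mode the saving.py diff below addresses, assuming PyYAML's default Dumper; the `Options` enum mirrors the one used in the new test:

    from enum import Enum

    import yaml


    class Options(str, Enum):
        option3name = "option3val"


    # The default Dumper falls back to a Python-specific tag for Enum members,
    # which omegaconf and yaml.safe_load cannot parse back:
    print(yaml.dump({"switch": Options.option3name}))
    # e.g. switch: !!python/object/apply:__main__.Options [option3val]

    # The member's name, by contrast, dumps as a plain, portable string:
    print(yaml.dump({"switch": Options.option3name.name}))
    # switch: option3name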
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 782fbe99c0425..04ecfe7335794 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -19,6 +19,7 @@
 import os
 from argparse import Namespace
 from copy import deepcopy
+from enum import Enum
 from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
 from warnings import warn
 
@@ -381,6 +382,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
     # drop paramaters which contain some strange datatypes as fsspec
     for k, v in hparams.items():
         try:
+            if isinstance(v, Enum):
+                v = deepcopy(v.name)
             yaml.dump(v)
         except TypeError:
             warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")

diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index a95bdf9ba76d7..e4dc6d1ec5e96 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -17,6 +17,7 @@
 import pickle
 from argparse import Namespace
 from dataclasses import dataclass
+from enum import Enum
 from unittest import mock
 
 import cloudpickle
@@ -24,6 +25,8 @@
 import torch
 from fsspec.implementations.local import LocalFileSystem
 from omegaconf import Container, OmegaConf
+from omegaconf.dictconfig import DictConfig
+
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import LightningModule, Trainer
@@ -477,22 +480,36 @@ def test_hparams_pickle_warning(tmpdir):
 
 
 def test_hparams_save_yaml(tmpdir):
+    class Options(str, Enum):
+        option1name = "option1val"
+        option2name = "option2val"
+        option3name = "option3val"
     hparams = dict(
-        batch_size=32, learning_rate=0.001, data_root="./any/path/here", nasted=dict(any_num=123, anystr="abcd")
+        batch_size=32, learning_rate=0.001, data_root="./any/path/here", nested=dict(any_num=123, anystr="abcd"),
+        switch= Options.option3name
     )
     path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
 
+    def compare(loadedParams, defaultParams: dict):
+        assert isinstance(loadedParams, (dict, DictConfig))
+        assert loadedParams.keys() == defaultParams.keys()
+        for k,v in defaultParams.items():
+            if isinstance(v, Enum):
+                assert v.name == loadedParams[k]
+            else:
+                assert v == loadedParams[k]
+
     save_hparams_to_yaml(path_yaml, hparams)
-    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
+    compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
 
     save_hparams_to_yaml(path_yaml, Namespace(**hparams))
-    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
+    compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
 
     save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
-    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
+    compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False),hparams)
 
     save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
-    assert load_hparams_from_yaml(path_yaml) == hparams
+    compare(load_hparams_from_yaml(path_yaml), hparams)
 
 
 class NoArgsSubClassBoringModel(CustomBoringModel):
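Condensed, a standalone sketch of the round trip the new test locks in (assuming this patch is applied; the enum and path names are illustrative). Note the Enum member is saved and compared by its `.name`, not its value:

    import os
    from enum import Enum

    from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml


    class Options(str, Enum):
        option3name = "option3val"


    path_yaml = "./testing-hparams.yaml"
    save_hparams_to_yaml(path_yaml, {"batch_size": 32, "switch": Options.option3name})

    loaded = load_hparams_from_yaml(path_yaml, use_omegaconf=False)
    assert loaded["switch"] == "option3name"  # the member's name, not "option3val"
    os.remove(path_yaml)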
From d74e0a8ae6d92934887355b67610c57fe2bb6a0b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 29 Aug 2021 06:35:21 +0000
Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/models/test_hparams.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index e4dc6d1ec5e96..fc69e09c56e82 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -26,7 +26,6 @@
 from fsspec.implementations.local import LocalFileSystem
 from omegaconf import Container, OmegaConf
 from omegaconf.dictconfig import DictConfig
-
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import LightningModule, Trainer
@@ -484,16 +483,20 @@ class Options(str, Enum):
         option1name = "option1val"
         option2name = "option2val"
         option3name = "option3val"
+
     hparams = dict(
-        batch_size=32, learning_rate=0.001, data_root="./any/path/here", nested=dict(any_num=123, anystr="abcd"),
-        switch= Options.option3name
+        batch_size=32,
+        learning_rate=0.001,
+        data_root="./any/path/here",
+        nested=dict(any_num=123, anystr="abcd"),
+        switch=Options.option3name,
     )
     path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
 
     def compare(loadedParams, defaultParams: dict):
         assert isinstance(loadedParams, (dict, DictConfig))
         assert loadedParams.keys() == defaultParams.keys()
-        for k,v in defaultParams.items():
+        for k, v in defaultParams.items():
             if isinstance(v, Enum):
                 assert v.name == loadedParams[k]
             else:
@@ -506,7 +509,7 @@ def compare(loadedParams, defaultParams: dict):
     compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
 
     save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
-    compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False),hparams)
+    compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
 
     save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
     compare(load_hparams_from_yaml(path_yaml), hparams)

From 76eb5dc1c2d61a09e9b84aa4de0521f853ab6d8d Mon Sep 17 00:00:00 2001
From: Rajat
Date: Mon, 30 Aug 2021 19:18:43 +0530
Subject: [PATCH 3/6] omegaconf flag added

---
 pytorch_lightning/core/saving.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 04ecfe7335794..6cbebc8d5643a 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -347,11 +347,14 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
     return hparams
 
 
-def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
+def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:
     """
     Args:
         config_yaml: path to new YAML file
        hparams: parameters to be saved
+        use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
+            the hparams will be converted to `DictConfig` if possible
+
     """
     fs = get_filesystem(config_yaml)
     if not fs.isdir(os.path.dirname(config_yaml)):
@@ -364,7 +367,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
         hparams = dict(hparams)
 
     # saving with OmegaConf objects
-    if _OMEGACONF_AVAILABLE:
+    if _OMEGACONF_AVAILABLE and use_omegaconf:
         # deepcopy: hparams from user shouldn't be resolved
         hparams = deepcopy(hparams)
         hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True)
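A hedged sketch of what the new flag toggles end to end, assuming omegaconf is installed (the file path is illustrative):

    from omegaconf.dictconfig import DictConfig

    from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

    # Saving: use_omegaconf=True (the default) serializes through OmegaConf;
    # use_omegaconf=False forces the plain-PyYAML branch shown in the diff above.
    save_hparams_to_yaml("./hparams.yaml", {"batch_size": 32})
    save_hparams_to_yaml("./hparams.yaml", {"batch_size": 32}, use_omegaconf=False)

    # Loading: the flag likewise selects the return type, which is why the test
    # helper accepts both dict and DictConfig.
    assert isinstance(load_hparams_from_yaml("./hparams.yaml"), DictConfig)
    assert isinstance(load_hparams_from_yaml("./hparams.yaml", use_omegaconf=False), dict)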
From fcd0fe1ee11644ff85084783891b83d9201f62ac Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 12 Oct 2021 14:55:31 +0100
Subject: [PATCH 4/6] resolve comments

---
 pytorch_lightning/core/saving.py |  8 ++++----
 tests/models/test_hparams.py     | 12 ++++++------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 6cbebc8d5643a..3d7e2da0a993a 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -319,8 +319,8 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
 
     Args:
         config_yaml: Path to config yaml file
-        use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
-            the hparams will be converted to `DictConfig` if possible
+        use_omegaconf: If omegaconf is available and `use_omegaconf=True`,
+            the hparams will be converted to ``DictConfig`` if possible.
 
     >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
     >>> path_yaml = './testing-hparams.yaml'
@@ -352,8 +352,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omega
     Args:
         config_yaml: path to new YAML file
         hparams: parameters to be saved
-        use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
-            the hparams will be converted to `DictConfig` if possible
+        use_omegaconf: If omegaconf is available and `use_omegaconf=True`,
+            the hparams will be converted to ``DictConfig`` if possible.
 
     """
     fs = get_filesystem(config_yaml)

diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index fc69e09c56e82..16734e2c35a84 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -493,14 +493,14 @@ class Options(str, Enum):
     )
     path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
 
-    def compare(loadedParams, defaultParams: dict):
-        assert isinstance(loadedParams, (dict, DictConfig))
-        assert loadedParams.keys() == defaultParams.keys()
-        for k, v in defaultParams.items():
+    def compare(loaded_params, default_params: dict):
+        assert isinstance(loaded_params, (dict, DictConfig))
+        assert loaded_params.keys() == default_params.keys()
+        for k, v in default_params.items():
             if isinstance(v, Enum):
-                assert v.name == loadedParams[k]
+                assert v.name == loaded_params[k]
             else:
-                assert v == loadedParams[k]
+                assert v == loaded_params[k]

From 62db00e7b1791cc0a68211e09063213e23f116df Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Wed, 13 Oct 2021 11:08:03 +0200
Subject: [PATCH 5/6] Apply suggestions from code review

Co-authored-by: Rohit Gupta
---
 pytorch_lightning/core/saving.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 3d7e2da0a993a..081b01007ec15 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -319,7 +319,7 @@ def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict
 
     Args:
         config_yaml: Path to config yaml file
-        use_omegaconf: If omegaconf is available and `use_omegaconf=True`,
+        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
             the hparams will be converted to ``DictConfig`` if possible.
 
     >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
@@ -352,8 +352,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omega
     Args:
         config_yaml: path to new YAML file
         hparams: parameters to be saved
-        use_omegaconf: If omegaconf is available and `use_omegaconf=True`,
-            the hparams will be converted to ``DictConfig`` if possible.
+        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
+            the hparams will be converted to ``DictConfig`` if possible.
""" fs = get_filesystem(config_yaml) From db5273388f7b9a9208b8188d558cc31318c78873 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 21 Oct 2021 16:24:11 +0530 Subject: [PATCH 6/6] code reviews and rebase --- pytorch_lightning/core/saving.py | 3 +-- tests/models/test_hparams.py | 10 +++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 081b01007ec15..2f9463e42fec8 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -385,8 +385,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace], use_omega # drop paramaters which contain some strange datatypes as fsspec for k, v in hparams.items(): try: - if isinstance(v, Enum): - v = deepcopy(v.name) + v = v.name if isinstance(v, Enum) else v yaml.dump(v) except TypeError: warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.") diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 16734e2c35a84..dbd51d33bf0ed 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -493,7 +493,7 @@ class Options(str, Enum): ) path_yaml = os.path.join(tmpdir, "testing-hparams.yaml") - def compare(loaded_params, default_params: dict): + def _compare_params(loaded_params, default_params: dict): assert isinstance(loaded_params, (dict, DictConfig)) assert loaded_params.keys() == default_params.keys() for k, v in default_params.items(): @@ -503,16 +503,16 @@ def compare(loaded_params, default_params: dict): assert v == loaded_params[k] save_hparams_to_yaml(path_yaml, hparams) - compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams) + _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams) save_hparams_to_yaml(path_yaml, Namespace(**hparams)) - compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams) + _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams) save_hparams_to_yaml(path_yaml, AttributeDict(hparams)) - compare(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams) + _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams) save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams)) - compare(load_hparams_from_yaml(path_yaml), hparams) + _compare_params(load_hparams_from_yaml(path_yaml), hparams) class NoArgsSubClassBoringModel(CustomBoringModel):