Skip to content

Commit c90dc65

Browse files
committed
Added unit test
1 parent 464d506 commit c90dc65

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

tests/models/test_hparams.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import copy
15+
from enum import Enum
1516
import functools
1617
import os
1718
import pickle
@@ -477,8 +478,13 @@ def test_hparams_pickle_warning(tmpdir):
477478

478479

479480
def test_hparams_save_yaml(tmpdir):
481+
class Options(str, Enum):
482+
option1 = "option1"
483+
option2 = "option2"
484+
option3 = "option3"
480485
hparams = dict(
481-
batch_size=32, learning_rate=0.001, data_root="./any/path/here", nasted=dict(any_num=123, anystr="abcd")
486+
batch_size=32, learning_rate=0.001, data_root="./any/path/here", nasted=dict(any_num=123, anystr="abcd"),
487+
switch= Options.option3
482488
)
483489
path_yaml = os.path.join(tmpdir, "testing-hparams.yaml")
484490

@@ -495,6 +501,7 @@ def test_hparams_save_yaml(tmpdir):
495501
assert load_hparams_from_yaml(path_yaml) == hparams
496502

497503

504+
498505
class NoArgsSubClassBoringModel(CustomBoringModel):
499506
def __init__(self):
500507
super().__init__()

0 commit comments

Comments
 (0)