Skip to content

Commit c2e84ad

Browse files
author
The ml_collections Authors
committed
Support typing.Literal in config flags.
PiperOrigin-RevId: 861237487
1 parent 347c0e2 commit c2e84ad

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

ml_collections/config_flags/config_flags.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import sys
2828
import traceback
2929
import types
30+
import typing
3031
from typing import Any, Callable, Dict, Generic, List, MutableMapping, Optional, Sequence, Tuple, Type, TypeVar
3132

3233
from absl import flags
@@ -859,7 +860,10 @@ def _parse(self, argument):
859860
parser = _FIELD_TYPE_TO_PARSER[config_dict.ConfigDict]
860861
elif field_type_origin and field_type_origin in _FIELD_TYPE_TO_PARSER:
861862
parser = _FIELD_TYPE_TO_PARSER[field_type_origin]
862-
elif issubclass(field_type, enum.Enum):
863+
elif field_type_origin is typing.Literal:
864+
# Literal types like Literal["a", "b"] should be treated as string-like.
865+
parser = _LiteralParser()
866+
elif isinstance(field_type, type) and issubclass(field_type, enum.Enum):
863867
parser = flags.EnumClassParser(field_type, case_sensitive=False)
864868
elif dataclasses.is_dataclass(field_type):
865869
# For dataclasses-valued fields allow default instance creation.

ml_collections/config_flags/tests/dataclass_overriding_test.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import dataclasses
1919
import functools
2020
import sys
21-
from typing import Mapping, Optional, Sequence, Tuple, Union
21+
from typing import Literal, Mapping, Optional, Sequence, Tuple, Union
2222
import unittest
2323

2424
from absl import flags
@@ -179,6 +179,17 @@ class PipeConfig:
179179
result = _test_flags(PipeConfig(), '.foo=32')
180180
self.assertEqual(result.foo, 32)
181181

182+
def test_literal_type_field(self):
183+
@dataclasses.dataclass
184+
class ConfigWithLiteral:
185+
optimizer: Literal['adam', 'sgd', 'rmsprop'] = 'adam'
186+
187+
result = _test_flags(ConfigWithLiteral())
188+
self.assertEqual(result.optimizer, 'adam')
189+
190+
result = _test_flags(ConfigWithLiteral(), '.optimizer=sgd')
191+
self.assertEqual(result.optimizer, 'sgd')
192+
182193
def test_custom_flag_parsing_override_work(self):
183194
# Overrides still work.
184195
result = _test_flags(_CONFIG, '.custom.i=10')

0 commit comments

Comments
 (0)