Skip to content

Commit 99e038b

Browse files
ConchylicultorThe ml_collections Authors
authored andcommitted
Always use Literal evals when parsing kauldron flags
PiperOrigin-RevId: 872335743
1 parent 91e91a5 commit 99e038b

File tree

3 files changed

+160
-35
lines changed

3 files changed

+160
-35
lines changed

ml_collections/config_flags/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .config_flags import DEFINE_config_file
2020
from .config_flags import get_config_filename
2121
from .config_flags import get_override_values
22+
from .config_flags import OverrideMode
2223
from .config_flags import register_flag_parser
2324
from .config_flags import register_flag_parser_for_type
2425

@@ -28,6 +29,7 @@
2829
"DEFINE_config_file",
2930
"get_config_filename",
3031
"get_override_values",
32+
"OverrideMode",
3133
"register_flag_parser",
3234
"register_flag_parser_for_type",
3335
)

ml_collections/config_flags/config_flags.py

Lines changed: 94 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Configuration commmand line parser."""
1616

1717
import ast
18+
import contextlib
1819
import copy
1920
import dataclasses
2021
import enum
@@ -48,6 +49,41 @@
4849
flags.disclaim_key_flags()
4950

5051

52+
class OverrideMode(enum.Enum):
53+
"""Behavior to overwrite config parameters types.
54+
55+
Attributes:
56+
ALWAYS: All `--cfg.xxx` flags are parsed using `_LiteralParser`.
57+
NEW_ONLY: `--cfg.xxx` is only parsed using `_LiteralParser`, if `cfg.xxx`
58+
does not exist yet.
59+
NEVER: (default) All `--cfg.xxx` flags are parsed using type-specific
60+
parsers (i.e. it is not possible to have flag supporting both int and
61+
str).
62+
"""
63+
64+
ALWAYS = enum.auto()
65+
NEW_ONLY = enum.auto()
66+
NEVER = enum.auto()
67+
68+
69+
def _normalize_accept_new_attributes(
70+
accept_new_attributes: bool | None = None,
71+
override_mode: OverrideMode | None = None,
72+
) -> OverrideMode:
73+
"""Normalizes override_mode."""
74+
if accept_new_attributes is not None and override_mode is not None:
75+
raise ValueError(
76+
'Cannot specify both `accept_new_attributes` and `override_mode`.'
77+
)
78+
if override_mode is not None:
79+
return override_mode
80+
if accept_new_attributes is None:
81+
return OverrideMode.NEVER
82+
if accept_new_attributes:
83+
return OverrideMode.NEW_ONLY
84+
return OverrideMode.NEVER
85+
86+
5187
def _load_source(module_name: str, module_path: str) -> types.ModuleType:
5288
"""Loads a Python module from its source file.
5389
@@ -145,9 +181,11 @@ def DEFINE_config_file( # pylint: disable=g-bad-name
145181
help_string: str = 'path to config file.',
146182
flag_values: flags.FlagValues = FLAGS,
147183
lock_config: bool = True,
148-
accept_new_attributes: bool = False,
184+
accept_new_attributes: bool | None = None,
185+
override_mode: OverrideMode | None = None,
149186
sys_argv: Optional[List[str]] = None,
150-
**kwargs) -> flags.FlagHolder:
187+
**kwargs,
188+
) -> flags.FlagHolder:
151189
r"""Defines flag for `ConfigDict` files compatible with absl flags.
152190
153191
The flag's value should be a path to a valid python file which contains a
@@ -247,9 +285,10 @@ def get_config(config_string):
247285
absl.flags.FLAGS)
248286
lock_config: If set to True, loaded config will be locked through calling
249287
.lock() method on its instance (if it exists). (default: True)
250-
accept_new_attributes: If `True`, accept to pass arbitrary attributes that
251-
are not originally defined in the `get_config()` dict.
252-
`accept_new_attributes` requires `lock_config=False`.
288+
accept_new_attributes: Deprecated, use `override_mode` instead. `False`:
289+
only existing attributes can be overridden. `True`: new attributes are
290+
accepted. Requires `lock_config=False`.
291+
override_mode: Controls how override flags are parsed.
253292
sys_argv: If set, interprets this as the full list of args used in parsing.
254293
This is used to identify which overrides to define as flags. If not
255294
specified, uses the system sys.argv to figure it out.
@@ -258,8 +297,12 @@ def get_config(config_string):
258297
Returns:
259298
a handle to defined flag.
260299
"""
261-
if accept_new_attributes and lock_config:
262-
raise ValueError('`accept_new_attributes=True` requires lock_config=False')
300+
301+
override_mode = _normalize_accept_new_attributes(
302+
accept_new_attributes, override_mode
303+
)
304+
if override_mode != OverrideMode.NEVER and lock_config:
305+
raise ValueError('`override_mode` requires `lock_config=False`')
263306
parser = ConfigFileFlagParser(name=name, lock_config=lock_config)
264307
serializer = flags.ArgumentSerializer()
265308
flag = _ConfigFlag(
@@ -269,9 +312,10 @@ def get_config(config_string):
269312
default=default,
270313
help_string=help_string,
271314
flag_values=flag_values,
272-
accept_new_attributes=accept_new_attributes,
315+
override_mode=override_mode,
273316
sys_argv=sys_argv,
274-
**kwargs)
317+
**kwargs,
318+
)
275319

276320
return flags.DEFINE_flag(flag, flag_values)
277321

@@ -282,9 +326,11 @@ def DEFINE_config_dict( # pylint: disable=g-bad-name
282326
help_string: str = 'ConfigDict instance.',
283327
flag_values: flags.FlagValues = FLAGS,
284328
lock_config: bool = True,
285-
accept_new_attributes: bool = False,
329+
accept_new_attributes: bool | None = None,
330+
override_mode: OverrideMode | None = None,
286331
sys_argv: Optional[List[str]] = None,
287-
**kwargs) -> flags.FlagHolder:
332+
**kwargs,
333+
) -> flags.FlagHolder:
288334
"""Defines flag for inline `ConfigDict's` compatible with absl flags.
289335
290336
Similar to `DEFINE_config_file` except the flag's value should be a
@@ -328,15 +374,16 @@ def DEFINE_config_dict( # pylint: disable=g-bad-name
328374
Args:
329375
name: Flag name.
330376
config: `ConfigDict` object.
331-
help_string: Help string to display when --helpfull is called.
332-
(default: "ConfigDict instance.")
333-
flag_values: FlagValues instance used for parsing.
334-
(default: absl.flags.FLAGS)
377+
help_string: Help string to display when --helpfull is called. (default:
378+
"ConfigDict instance.")
379+
flag_values: FlagValues instance used for parsing. (default:
380+
absl.flags.FLAGS)
335381
lock_config: If set to True, loaded config will be locked through calling
336-
.lock() method on its instance (if it exists). (default: True)
337-
accept_new_attributes: If `True`, accept to pass arbitrary attributes that
338-
are not originally defined in the `config` argument.
339-
`accept_new_attributes` requires `lock_config=False`.
382+
.lock() method on its instance (if it exists). (default: True)
383+
accept_new_attributes: Deprecated, use `override_mode` instead. `False`:
384+
only existing attributes can be overridden. `True`: new attributes are
385+
accepted. Requires `lock_config=False`.
386+
override_mode: Controls how override flags are parsed.
340387
sys_argv: If set, interprets this as the full list of args used in parsing.
341388
This is used to identify which overrides to define as flags. If not
342389
specified, uses the system sys.argv to figure it out.
@@ -347,8 +394,11 @@ def DEFINE_config_dict( # pylint: disable=g-bad-name
347394
"""
348395
if not isinstance(config, config_dict.ConfigDict):
349396
raise TypeError('config should be a ConfigDict')
350-
if accept_new_attributes and lock_config:
351-
raise ValueError('`accept_new_attributes=True` requires lock_config=False')
397+
override_mode = _normalize_accept_new_attributes(
398+
accept_new_attributes, override_mode
399+
)
400+
if override_mode != OverrideMode.NEVER and lock_config:
401+
raise ValueError('`override_mode` requires `lock_config=False`')
352402
parser = _InlineConfigParser(name=name, lock_config=lock_config)
353403
flag = _ConfigFlag(
354404
parser=parser,
@@ -357,9 +407,10 @@ def DEFINE_config_dict( # pylint: disable=g-bad-name
357407
default=config,
358408
help_string=help_string,
359409
flag_values=flag_values,
360-
accept_new_attributes=accept_new_attributes,
410+
override_mode=override_mode,
361411
sys_argv=sys_argv,
362-
**kwargs)
412+
**kwargs,
413+
)
363414

364415
# Get the module name for the frame at depth 1 in the call stack.
365416
module_name = sys._getframe(1).f_globals.get('__name__', None) # pylint: disable=protected-access
@@ -747,14 +798,14 @@ def __init__(
747798
self,
748799
flag_values=FLAGS,
749800
*,
750-
accept_new_attributes: bool = False,
801+
override_mode: OverrideMode = OverrideMode.NEVER,
751802
sys_argv=None,
752803
**kwargs,
753804
):
754805
# Parent constructor can already call .Parse, thus additional fields
755806
# have to be set here.
756807
self.flag_values = flag_values
757-
self._accept_new_attributes = accept_new_attributes
808+
self._override_mode = override_mode
758809
# Note, we don't replace sys_argv with sys.argv here if it's None because
759810
# in some obscure multiprocessing use cases, sys.argv may not be populated
760811
# until later and we need to look it up at parse time.
@@ -839,7 +890,7 @@ def _parse(self, argument):
839890
self._initialize_missing_parent_fields(config, overrides)
840891
self._validate_overrides(config, overrides)
841892

842-
if self._accept_new_attributes:
893+
if self._override_mode != OverrideMode.NEVER:
843894
# If user provide a new attribute, fallback to `object` to accept all
844895
# literal
845896
default_type = object
@@ -855,7 +906,9 @@ def _parse(self, argument):
855906
field_name = '{}.{}'.format(self.name, field_path)
856907

857908
parser = None
858-
if field_type in _FIELD_TYPE_TO_PARSER:
909+
if self._override_mode == OverrideMode.ALWAYS:
910+
parser = _LiteralParser()
911+
elif field_type in _FIELD_TYPE_TO_PARSER:
859912
parser = _FIELD_TYPE_TO_PARSER[field_type]
860913
elif isinstance(field_type, type) and issubclass(
861914
field_type, config_dict.ConfigDict
@@ -903,7 +956,7 @@ def _parse(self, argument):
903956
serializer=serializer,
904957
name=field_name,
905958
default=default,
906-
accept_new_attributes=self._accept_new_attributes,
959+
override_mode=self._override_mode,
907960
help_string=field_help,
908961
)
909962
# Literal values support the `--my_bool` / `--nomy_bool` syntax
@@ -1073,7 +1126,7 @@ def __init__(
10731126
help_string: str,
10741127
short_name: Optional[str] = None,
10751128
boolean: bool = False,
1076-
accept_new_attributes: bool = False,
1129+
override_mode: OverrideMode = OverrideMode.NEVER,
10771130
):
10781131
"""Creates new flag with callback."""
10791132
super().__init__(
@@ -1087,15 +1140,23 @@ def __init__(
10871140
self._path = path
10881141
self._config = config
10891142
self._override_values = override_values
1090-
self._accept_new_attributes = accept_new_attributes
1143+
self._override_mode = override_mode
10911144

10921145
def parse(self, argument):
10931146
super().parse(argument)
10941147
# Callback to set value in ConfigDict.
1095-
config_path.set_value(
1096-
self._path, self._config, self.value,
1097-
accept_new_attributes=self._accept_new_attributes,
1148+
ctx = (
1149+
self._config.ignore_type()
1150+
if self._override_mode == OverrideMode.ALWAYS
1151+
else contextlib.nullcontext()
10981152
)
1153+
with ctx:
1154+
config_path.set_value(
1155+
self._path,
1156+
self._config,
1157+
self.value,
1158+
accept_new_attributes=self._override_mode != OverrideMode.NEVER,
1159+
)
10991160
self._override_values[self._path] = self.value
11001161

11011162

ml_collections/config_flags/tests/config_overriding_test.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def _parse_flags(
6060
config=None,
6161
lock_config=True,
6262
required=False,
63-
accept_new_attributes=False,
63+
accept_new_attributes=None,
64+
override_mode=None,
6465
use_sys_argv_override=False,
6566
):
6667
"""Parses arguments simulating sys.argv or via sys_argv argument."""
@@ -87,6 +88,7 @@ def _parse_flags(
8788
flag_values=values,
8889
lock_config=lock_config,
8990
accept_new_attributes=accept_new_attributes,
91+
override_mode=override_mode,
9092
sys_argv=(argv if use_sys_argv_override else None),
9193
)
9294
else:
@@ -95,7 +97,10 @@ def _parse_flags(
9597
config=config,
9698
flag_values=values,
9799
lock_config=lock_config,
98-
sys_argv=(argv if use_sys_argv_override else None))
100+
accept_new_attributes=accept_new_attributes,
101+
override_mode=override_mode,
102+
sys_argv=(argv if use_sys_argv_override else None),
103+
)
99104

100105
if required:
101106
flags.mark_flag_as_required('test_config', flag_values=values)
@@ -726,6 +731,63 @@ def testNewAttributes(self):
726731
self.assertEqual(cfg.new_value, config_dict.ConfigDict({'a': [1, 2, 3]}))
727732
self.assertEqual(cfg.new_bool_value, True)
728733

734+
def testUseLiteralParser(self):
735+
config = config_dict.ConfigDict({
736+
'integer': 1,
737+
'float': 2.0,
738+
'string': 'hello',
739+
'nested': {
740+
'value': 10,
741+
}
742+
})
743+
values = _parse_flags(
744+
'./program'
745+
' --test_config.integer="hello"'
746+
' --test_config.float=3'
747+
' --test_config.string="world"'
748+
' --test_config.nested.value=99'
749+
' --test_config.new_value="new"',
750+
config=config,
751+
lock_config=False,
752+
override_mode=config_flags.OverrideMode.ALWAYS,
753+
)
754+
cfg = values.test_config
755+
self.assertEqual(cfg.integer, 'hello')
756+
self.assertEqual(cfg.float, 3)
757+
self.assertEqual(cfg.string, 'world')
758+
self.assertEqual(cfg.nested.value, 99)
759+
self.assertEqual(cfg.new_value, 'new')
760+
761+
def testOverrideModeNewAttributes(self):
762+
values = _parse_flags(
763+
'./program'
764+
f' --test_config={_LITERAL_CONFIG_FILE}'
765+
' --test_config.integer=123'
766+
' --test_config.other_new_value="abc def"'
767+
' --test_config.new_value="{\'a\': [1, 2, 3]}"'
768+
' --test_config.new_bool_value=true',
769+
override_mode=config_flags.OverrideMode.NEW_ONLY,
770+
lock_config=False,
771+
)
772+
cfg = values.test_config
773+
self.assertEqual(cfg.integer, 123)
774+
self.assertIsNone(cfg.string)
775+
self.assertEqual(cfg.other_new_value, 'abc def')
776+
self.assertEqual(cfg.new_value, config_dict.ConfigDict({'a': [1, 2, 3]}))
777+
self.assertEqual(cfg.new_bool_value, True)
778+
779+
def testOverrideModeBothParametersRaisesError(self):
780+
with self.assertRaisesRegex(
781+
ValueError,
782+
'Cannot specify both `accept_new_attributes` and `override_mode`.',
783+
):
784+
_parse_flags(
785+
f'./program --test_config={_LITERAL_CONFIG_FILE}',
786+
accept_new_attributes=True,
787+
override_mode=config_flags.OverrideMode.NEW_ONLY,
788+
lock_config=False,
789+
)
790+
729791

730792
def _simple_config():
731793
config = config_dict.ConfigDict()

0 commit comments

Comments
 (0)