Skip to content

Commit 5794912

Browse files
ConchylicultorThe ml_collections Authors
authored andcommitted
Always use Literal evals when parsing kauldron flags
PiperOrigin-RevId: 872335743
1 parent 91e91a5 commit 5794912

File tree

3 files changed

+159
-36
lines changed

3 files changed

+159
-36
lines changed

ml_collections/config_flags/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .config_flags import DEFINE_config_file
2020
from .config_flags import get_config_filename
2121
from .config_flags import get_override_values
22+
from .config_flags import OverrideMode
2223
from .config_flags import register_flag_parser
2324
from .config_flags import register_flag_parser_for_type
2425

@@ -28,6 +29,7 @@
2829
"DEFINE_config_file",
2930
"get_config_filename",
3031
"get_override_values",
32+
"OverrideMode",
3133
"register_flag_parser",
3234
"register_flag_parser_for_type",
3335
)

ml_collections/config_flags/config_flags.py

Lines changed: 93 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Configuration commmand line parser."""
1616

1717
import ast
18+
import contextlib
1819
import copy
1920
import dataclasses
2021
import enum
@@ -48,6 +49,40 @@
4849
flags.disclaim_key_flags()
4950

5051

52+
class OverrideMode(enum.Enum):
53+
"""Behavior to overwrite config parameters types.
54+
55+
Attributes:
56+
ALWAYS: All `--cfg.xxx` flags are parsed using `_LiteralParser`.
57+
NEW_ONLY: `--cfg.xxx` is only parsed using `_LiteralParser`, if `cfg.xxx`
58+
does not exist yet.
59+
NEVER: All `--cfg.xxx` flags are parsed using type-specific parsers (i.e. it
60+
is not possible to have flag supporting both int and str).
61+
"""
62+
63+
ALWAYS = enum.auto()
64+
NEW_ONLY = enum.auto()
65+
NEVER = enum.auto()
66+
67+
68+
def _normalize_accept_new_attributes(
69+
accept_new_attributes: bool | None = None,
70+
override_mode: OverrideMode | None = None,
71+
) -> OverrideMode:
72+
"""Normalizes override_mode."""
73+
if accept_new_attributes is not None and override_mode is not None:
74+
raise ValueError(
75+
'Cannot specify both `accept_new_attributes` and `override_mode`.'
76+
)
77+
if override_mode is not None:
78+
return override_mode
79+
if accept_new_attributes is None:
80+
return OverrideMode.NEVER
81+
if accept_new_attributes:
82+
return OverrideMode.NEW_ONLY
83+
return OverrideMode.NEVER
84+
85+
5186
def _load_source(module_name: str, module_path: str) -> types.ModuleType:
5287
"""Loads a Python module from its source file.
5388
@@ -145,9 +180,11 @@ def DEFINE_config_file( # pylint: disable=g-bad-name
145180
help_string: str = 'path to config file.',
146181
flag_values: flags.FlagValues = FLAGS,
147182
lock_config: bool = True,
148-
accept_new_attributes: bool = False,
183+
accept_new_attributes: bool | None = None,
184+
override_mode: OverrideMode | None = None,
149185
sys_argv: Optional[List[str]] = None,
150-
**kwargs) -> flags.FlagHolder:
186+
**kwargs,
187+
) -> flags.FlagHolder:
151188
r"""Defines flag for `ConfigDict` files compatible with absl flags.
152189
153190
The flag's value should be a path to a valid python file which contains a
@@ -247,9 +284,10 @@ def get_config(config_string):
247284
absl.flags.FLAGS)
248285
lock_config: If set to True, loaded config will be locked through calling
249286
.lock() method on its instance (if it exists). (default: True)
250-
accept_new_attributes: If `True`, accept to pass arbitrary attributes that
251-
are not originally defined in the `get_config()` dict.
252-
`accept_new_attributes` requires `lock_config=False`.
287+
accept_new_attributes: Deprecated, use `override_mode` instead. `False`:
288+
only existing attributes can be overridden. `True`: new attributes are
289+
accepted. Requires `lock_config=False`.
290+
override_mode: Controls how override flags are parsed.
253291
sys_argv: If set, interprets this as the full list of args used in parsing.
254292
This is used to identify which overrides to define as flags. If not
255293
specified, uses the system sys.argv to figure it out.
@@ -258,8 +296,12 @@ def get_config(config_string):
258296
Returns:
259297
a handle to defined flag.
260298
"""
261-
if accept_new_attributes and lock_config:
262-
raise ValueError('`accept_new_attributes=True` requires lock_config=False')
299+
300+
override_mode = _normalize_accept_new_attributes(
301+
accept_new_attributes, override_mode
302+
)
303+
if override_mode != OverrideMode.NEVER and lock_config:
304+
raise ValueError('`override_mode` requires `lock_config=False`')
263305
parser = ConfigFileFlagParser(name=name, lock_config=lock_config)
264306
serializer = flags.ArgumentSerializer()
265307
flag = _ConfigFlag(
@@ -269,9 +311,10 @@ def get_config(config_string):
269311
default=default,
270312
help_string=help_string,
271313
flag_values=flag_values,
272-
accept_new_attributes=accept_new_attributes,
314+
override_mode=override_mode,
273315
sys_argv=sys_argv,
274-
**kwargs)
316+
**kwargs,
317+
)
275318

276319
return flags.DEFINE_flag(flag, flag_values)
277320

@@ -282,9 +325,11 @@ def DEFINE_config_dict( # pylint: disable=g-bad-name
282325
help_string: str = 'ConfigDict instance.',
283326
flag_values: flags.FlagValues = FLAGS,
284327
lock_config: bool = True,
285-
accept_new_attributes: bool = False,
328+
accept_new_attributes: bool | None = None,
329+
override_mode: OverrideMode | None = None,
286330
sys_argv: Optional[List[str]] = None,
287-
**kwargs) -> flags.FlagHolder:
331+
**kwargs,
332+
) -> flags.FlagHolder:
288333
"""Defines flag for inline `ConfigDict's` compatible with absl flags.
289334
290335
Similar to `DEFINE_config_file` except the flag's value should be a
@@ -328,15 +373,16 @@ def DEFINE_config_dict( # pylint: disable=g-bad-name
328373
Args:
329374
name: Flag name.
330375
config: `ConfigDict` object.
331-
help_string: Help string to display when --helpfull is called.
332-
(default: "ConfigDict instance.")
333-
flag_values: FlagValues instance used for parsing.
334-
(default: absl.flags.FLAGS)
376+
help_string: Help string to display when --helpfull is called. (default:
377+
"ConfigDict instance.")
378+
flag_values: FlagValues instance used for parsing. (default:
379+
absl.flags.FLAGS)
335380
lock_config: If set to True, loaded config will be locked through calling
336-
.lock() method on its instance (if it exists). (default: True)
337-
accept_new_attributes: If `True`, accept to pass arbitrary attributes that
338-
are not originally defined in the `config` argument.
339-
`accept_new_attributes` requires `lock_config=False`.
381+
.lock() method on its instance (if it exists). (default: True)
382+
accept_new_attributes: Deprecated, use `override_mode` instead. `False`:
383+
only existing attributes can be overridden. `True`: new attributes are
384+
accepted. Requires `lock_config=False`.
385+
override_mode: Controls how override flags are parsed.
340386
sys_argv: If set, interprets this as the full list of args used in parsing.
341387
This is used to identify which overrides to define as flags. If not
342388
specified, uses the system sys.argv to figure it out.
@@ -347,8 +393,11 @@ def DEFINE_config_dict( # pylint: disable=g-bad-name
347393
"""
348394
if not isinstance(config, config_dict.ConfigDict):
349395
raise TypeError('config should be a ConfigDict')
350-
if accept_new_attributes and lock_config:
351-
raise ValueError('`accept_new_attributes=True` requires lock_config=False')
396+
override_mode = _normalize_accept_new_attributes(
397+
accept_new_attributes, override_mode
398+
)
399+
if override_mode != OverrideMode.NEVER and lock_config:
400+
raise ValueError('`override_mode` requires `lock_config=False`')
352401
parser = _InlineConfigParser(name=name, lock_config=lock_config)
353402
flag = _ConfigFlag(
354403
parser=parser,
@@ -357,9 +406,10 @@ def DEFINE_config_dict( # pylint: disable=g-bad-name
357406
default=config,
358407
help_string=help_string,
359408
flag_values=flag_values,
360-
accept_new_attributes=accept_new_attributes,
409+
override_mode=override_mode,
361410
sys_argv=sys_argv,
362-
**kwargs)
411+
**kwargs,
412+
)
363413

364414
# Get the module name for the frame at depth 1 in the call stack.
365415
module_name = sys._getframe(1).f_globals.get('__name__', None) # pylint: disable=protected-access
@@ -747,14 +797,14 @@ def __init__(
747797
self,
748798
flag_values=FLAGS,
749799
*,
750-
accept_new_attributes: bool = False,
800+
override_mode: OverrideMode = OverrideMode.NEVER,
751801
sys_argv=None,
752802
**kwargs,
753803
):
754804
# Parent constructor can already call .Parse, thus additional fields
755805
# have to be set here.
756806
self.flag_values = flag_values
757-
self._accept_new_attributes = accept_new_attributes
807+
self._override_mode = override_mode
758808
# Note, we don't replace sys_argv with sys.argv here if it's None because
759809
# in some obscure multiprocessing use cases, sys.argv may not be populated
760810
# until later and we need to look it up at parse time.
@@ -839,7 +889,7 @@ def _parse(self, argument):
839889
self._initialize_missing_parent_fields(config, overrides)
840890
self._validate_overrides(config, overrides)
841891

842-
if self._accept_new_attributes:
892+
if self._override_mode != OverrideMode.NEVER:
843893
# If user provide a new attribute, fallback to `object` to accept all
844894
# literal
845895
default_type = object
@@ -855,7 +905,9 @@ def _parse(self, argument):
855905
field_name = '{}.{}'.format(self.name, field_path)
856906

857907
parser = None
858-
if field_type in _FIELD_TYPE_TO_PARSER:
908+
if self._override_mode == OverrideMode.ALWAYS:
909+
parser = _LiteralParser()
910+
elif field_type in _FIELD_TYPE_TO_PARSER:
859911
parser = _FIELD_TYPE_TO_PARSER[field_type]
860912
elif isinstance(field_type, type) and issubclass(
861913
field_type, config_dict.ConfigDict
@@ -870,7 +922,6 @@ def _parse(self, argument):
870922
elif isinstance(field_type, type) and issubclass(field_type, enum.Enum):
871923
parser = flags.EnumClassParser(field_type, case_sensitive=False)
872924
elif dataclasses.is_dataclass(field_type):
873-
# For dataclasses-valued fields allow default instance creation.
874925
is_optional = config_path.is_optional(field_path, config)
875926
parser = _DataclassParser(
876927
name=field_path, dataclass_type=field_type,
@@ -903,7 +954,7 @@ def _parse(self, argument):
903954
serializer=serializer,
904955
name=field_name,
905956
default=default,
906-
accept_new_attributes=self._accept_new_attributes,
957+
override_mode=self._override_mode,
907958
help_string=field_help,
908959
)
909960
# Literal values support the `--my_bool` / `--nomy_bool` syntax
@@ -1073,7 +1124,7 @@ def __init__(
10731124
help_string: str,
10741125
short_name: Optional[str] = None,
10751126
boolean: bool = False,
1076-
accept_new_attributes: bool = False,
1127+
override_mode: OverrideMode = OverrideMode.NEVER,
10771128
):
10781129
"""Creates new flag with callback."""
10791130
super().__init__(
@@ -1087,15 +1138,23 @@ def __init__(
10871138
self._path = path
10881139
self._config = config
10891140
self._override_values = override_values
1090-
self._accept_new_attributes = accept_new_attributes
1141+
self._override_mode = override_mode
10911142

10921143
def parse(self, argument):
10931144
super().parse(argument)
10941145
# Callback to set value in ConfigDict.
1095-
config_path.set_value(
1096-
self._path, self._config, self.value,
1097-
accept_new_attributes=self._accept_new_attributes,
1146+
ctx = (
1147+
self._config.ignore_type()
1148+
if self._override_mode == OverrideMode.ALWAYS
1149+
else contextlib.nullcontext()
10981150
)
1151+
with ctx:
1152+
config_path.set_value(
1153+
self._path,
1154+
self._config,
1155+
self.value,
1156+
accept_new_attributes=self._override_mode != OverrideMode.NEVER,
1157+
)
10991158
self._override_values[self._path] = self.value
11001159

11011160

ml_collections/config_flags/tests/config_overriding_test.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def _parse_flags(
6060
config=None,
6161
lock_config=True,
6262
required=False,
63-
accept_new_attributes=False,
63+
accept_new_attributes=None,
64+
override_mode=None,
6465
use_sys_argv_override=False,
6566
):
6667
"""Parses arguments simulating sys.argv or via sys_argv argument."""
@@ -87,6 +88,7 @@ def _parse_flags(
8788
flag_values=values,
8889
lock_config=lock_config,
8990
accept_new_attributes=accept_new_attributes,
91+
override_mode=override_mode,
9092
sys_argv=(argv if use_sys_argv_override else None),
9193
)
9294
else:
@@ -95,7 +97,10 @@ def _parse_flags(
9597
config=config,
9698
flag_values=values,
9799
lock_config=lock_config,
98-
sys_argv=(argv if use_sys_argv_override else None))
100+
accept_new_attributes=accept_new_attributes,
101+
override_mode=override_mode,
102+
sys_argv=(argv if use_sys_argv_override else None),
103+
)
99104

100105
if required:
101106
flags.mark_flag_as_required('test_config', flag_values=values)
@@ -726,6 +731,63 @@ def testNewAttributes(self):
726731
self.assertEqual(cfg.new_value, config_dict.ConfigDict({'a': [1, 2, 3]}))
727732
self.assertEqual(cfg.new_bool_value, True)
728733

734+
def testUseLiteralParser(self):
735+
config = config_dict.ConfigDict({
736+
'integer': 1,
737+
'float': 2.0,
738+
'string': 'hello',
739+
'nested': {
740+
'value': 10,
741+
}
742+
})
743+
values = _parse_flags(
744+
'./program'
745+
' --test_config.integer="hello"'
746+
' --test_config.float=3'
747+
' --test_config.string="world"'
748+
' --test_config.nested.value=99'
749+
' --test_config.new_value="new"',
750+
config=config,
751+
lock_config=False,
752+
override_mode=config_flags.OverrideMode.ALWAYS,
753+
)
754+
cfg = values.test_config
755+
self.assertEqual(cfg.integer, 'hello')
756+
self.assertEqual(cfg.float, 3)
757+
self.assertEqual(cfg.string, 'world')
758+
self.assertEqual(cfg.nested.value, 99)
759+
self.assertEqual(cfg.new_value, 'new')
760+
761+
def testOverrideModeNewAttributes(self):
762+
values = _parse_flags(
763+
'./program'
764+
f' --test_config={_LITERAL_CONFIG_FILE}'
765+
' --test_config.integer=123'
766+
' --test_config.other_new_value="abc def"'
767+
' --test_config.new_value="{\'a\': [1, 2, 3]}"'
768+
' --test_config.new_bool_value=true',
769+
override_mode=config_flags.LiteralParserMode.DEFAULT_ONLY,
770+
lock_config=False,
771+
)
772+
cfg = values.test_config
773+
self.assertEqual(cfg.integer, 123)
774+
self.assertIsNone(cfg.string)
775+
self.assertEqual(cfg.other_new_value, 'abc def')
776+
self.assertEqual(cfg.new_value, config_dict.ConfigDict({'a': [1, 2, 3]}))
777+
self.assertEqual(cfg.new_bool_value, True)
778+
779+
def testOverrideModeBothParametersRaisesError(self):
780+
with self.assertRaisesRegex(
781+
ValueError,
782+
'Cannot specify both `accept_new_attributes` and `override_mode`.',
783+
):
784+
_parse_flags(
785+
f'./program --test_config={_LITERAL_CONFIG_FILE}',
786+
accept_new_attributes=True,
787+
override_mode=config_flags.OverrideMode.NEW_ONLY,
788+
lock_config=False,
789+
)
790+
729791

730792
def _simple_config():
731793
config = config_dict.ConfigDict()

0 commit comments

Comments
 (0)