Skip to content

Commit 7f28e7c

Browse files
author
The ml_collections Authors
committed
Adding optional serializer to register_flag_parser and register_flag_parser_for_type.
This is necessary to make parser flags that are not compatible with flags.ArgumentSerializer work smoothly. PiperOrigin-RevId: 829081645
1 parent 95c20f4 commit 7f28e7c

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

ml_collections/config_flags/config_flags.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ def parse(self, argument: str) -> config_dict.ConfigDict:
116116
config_dict.ConfigDict: _ConfigDictParser(),
117117
object: _LiteralParser(),
118118
}
119+
# To serialize a new type, add a serializer to this dict.
120+
# By default flags.ArgumentSerializer will be used.
121+
_FIELD_TYPE_TO_SERIALIZER = {
122+
}
119123

120124

121125
class UnsupportedOperationError(flags.Error):
@@ -865,6 +869,12 @@ def _parse(self, argument):
865869
parse_fn=ft.partial(_MakeDefaultOrNone, field_type,
866870
allow_none=is_optional, field_path=field_path))
867871

872+
serializer = flags.ArgumentSerializer()
873+
if field_type in _FIELD_TYPE_TO_SERIALIZER:
874+
serializer = _FIELD_TYPE_TO_SERIALIZER[field_type]
875+
if field_type_origin in _FIELD_TYPE_TO_SERIALIZER:
876+
serializer = _FIELD_TYPE_TO_SERIALIZER[field_type_origin]
877+
868878
if parser:
869879
if not isinstance(parser, tuple_parser.TupleParser):
870880
if isinstance(parser, (_LiteralParser, _ConfigDictParser)):
@@ -882,7 +892,7 @@ def _parse(self, argument):
882892
config=config,
883893
override_values=self._override_values,
884894
parser=parser,
885-
serializer=flags.ArgumentSerializer(),
895+
serializer=serializer,
886896
name=field_name,
887897
default=default,
888898
accept_new_attributes=self._accept_new_attributes,
@@ -904,7 +914,7 @@ def _parse(self, argument):
904914
config=config,
905915
override_values=self._override_values,
906916
parser=parser,
907-
serializer=flags.ArgumentSerializer(),
917+
serializer=serializer,
908918
name=field_name,
909919
default=config_path.get_value(field_path, config),
910920
help_string=field_help,
@@ -1126,23 +1136,32 @@ def _parse(self, arguments):
11261136

11271137

11281138
def register_flag_parser_for_type(
1129-
field_type: _T, parser: flags.ArgumentParser) -> _T:
1139+
field_type: _T,
1140+
parser: flags.ArgumentParser,
1141+
serializer: flags.ArgumentSerializer = flags.ArgumentSerializer(),
1142+
) -> _T:
11301143
"""Registers parser for a given type.
11311144
11321145
See documentation for `register_flag_parser` for usage example.
11331146
11341147
Args:
11351148
field_type: field type to register
11361149
parser: parser to use
1150+
serializer: serializer to use
11371151
11381152
Returns:
11391153
field_type unmodified.
11401154
"""
11411155
_FIELD_TYPE_TO_PARSER[field_type] = parser
1156+
_FIELD_TYPE_TO_SERIALIZER[field_type] = serializer
11421157
return field_type
11431158

11441159

1145-
def register_flag_parser(*, parser: flags.ArgumentParser) -> Callable[[_T], _T]:
1160+
def register_flag_parser(
1161+
*,
1162+
parser: flags.ArgumentParser,
1163+
serializer: flags.ArgumentSerializer = flags.ArgumentSerializer(),
1164+
) -> Callable[[_T], _T]:
11461165
"""Creates a decorator to register parser on types.
11471166
11481167
For example:
@@ -1174,8 +1193,11 @@ class MainConfig:
11741193
11751194
Args:
11761195
parser: parser to use.
1196+
serializer: serializer to use.
11771197
11781198
Returns:
11791199
Decorator to apply to types.
11801200
"""
1181-
return ft.partial(register_flag_parser_for_type, parser=parser)
1201+
return ft.partial(
1202+
register_flag_parser_for_type, parser=parser, serializer=serializer
1203+
)

ml_collections/config_flags/tests/config_overriding_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ def testOverridesSerialize(self):
795795
' --test_config.type_int=10'
796796
' --test_config.type_str=str_commandline'
797797
' --test_config.type_tuple="(\'tuple_str\', 10)"'
798-
)
798+
)
799799
command_line += ' --test_config.type_ustr=ustr_commandline'
800800
values = _parse_flags(command_line, config=copy.copy(all_types_config))
801801

0 commit comments

Comments
 (0)