Skip to content

Commit 40c5b43

Browse files
author
The ml_collections Authors
committed
Adding optional serializer to register_flag_parser and register_flag_parser_for_type.
This is necessary to make parser flags that are not compatible with flags.ArgumentSerializer work smoothly. PiperOrigin-RevId: 829081645
1 parent 95c20f4 commit 40c5b43

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

ml_collections/config_flags/config_flags.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Configuration commmand line parser."""
1616

1717
import ast
18+
import collections
1819
import copy
1920
import dataclasses
2021
import enum
@@ -116,6 +117,10 @@ def parse(self, argument: str) -> config_dict.ConfigDict:
116117
config_dict.ConfigDict: _ConfigDictParser(),
117118
object: _LiteralParser(),
118119
}
120+
# To serialize a new type, add a serializer to this dict.
121+
# By default flags.ArgumentSerializer will be used.
122+
_FIELD_TYPE_TO_SERIALIZER = {
123+
}
119124

120125

121126
class UnsupportedOperationError(flags.Error):
@@ -865,6 +870,12 @@ def _parse(self, argument):
865870
parse_fn=ft.partial(_MakeDefaultOrNone, field_type,
866871
allow_none=is_optional, field_path=field_path))
867872

873+
serializer = flags.ArgumentSerializer()
874+
if field_type in _FIELD_TYPE_TO_SERIALIZER:
875+
serializer = _FIELD_TYPE_TO_SERIALIZER[field_type]
876+
if field_type_origin in _FIELD_TYPE_TO_SERIALIZER:
877+
serializer = _FIELD_TYPE_TO_SERIALIZER[field_type_origin]
878+
868879
if parser:
869880
if not isinstance(parser, tuple_parser.TupleParser):
870881
if isinstance(parser, (_LiteralParser, _ConfigDictParser)):
@@ -877,12 +888,13 @@ def _parse(self, argument):
877888
default = None
878889
else:
879890
default = config_path.get_value(field_path, config)
891+
880892
flag = _ConfigFieldFlag(
881893
path=field_path,
882894
config=config,
883895
override_values=self._override_values,
884896
parser=parser,
885-
serializer=flags.ArgumentSerializer(),
897+
serializer=serializer,
886898
name=field_name,
887899
default=default,
888900
accept_new_attributes=self._accept_new_attributes,
@@ -904,7 +916,7 @@ def _parse(self, argument):
904916
config=config,
905917
override_values=self._override_values,
906918
parser=parser,
907-
serializer=flags.ArgumentSerializer(),
919+
serializer=serializer,
908920
name=field_name,
909921
default=config_path.get_value(field_path, config),
910922
help_string=field_help,
@@ -1126,23 +1138,30 @@ def _parse(self, arguments):
11261138

11271139

11281140
def register_flag_parser_for_type(
1129-
field_type: _T, parser: flags.ArgumentParser) -> _T:
1141+
field_type: _T,
1142+
parser: flags.ArgumentParser,
1143+
serializer=flags.ArgumentSerializer(),
1144+
) -> _T:
11301145
"""Registers parser for a given type.
11311146
11321147
See documentation for `register_flag_parser` for usage example.
11331148
11341149
Args:
11351150
field_type: field type to register
11361151
parser: parser to use
1152+
serializer: serializer to use
11371153
11381154
Returns:
11391155
field_type unmodified.
11401156
"""
11411157
_FIELD_TYPE_TO_PARSER[field_type] = parser
1158+
_FIELD_TYPE_TO_SERIALIZER[field_type] = serializer
11421159
return field_type
11431160

11441161

1145-
def register_flag_parser(*, parser: flags.ArgumentParser) -> Callable[[_T], _T]:
1162+
def register_flag_parser(
1163+
*, parser: flags.ArgumentParser, serializer=flags.ArgumentSerializer()
1164+
) -> Callable[[_T], _T]:
11461165
"""Creates a decorator to register parser on types.
11471166
11481167
For example:
@@ -1174,8 +1193,11 @@ class MainConfig:
11741193
11751194
Args:
11761195
parser: parser to use.
1196+
serializer: serializer to use.
11771197
11781198
Returns:
11791199
Decorator to apply to types.
11801200
"""
1181-
return ft.partial(register_flag_parser_for_type, parser=parser)
1201+
return ft.partial(
1202+
register_flag_parser_for_type, parser=parser, serializer=serializer
1203+
)

ml_collections/config_flags/tests/config_overriding_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ def testOverridesSerialize(self):
795795
' --test_config.type_int=10'
796796
' --test_config.type_str=str_commandline'
797797
' --test_config.type_tuple="(\'tuple_str\', 10)"'
798-
)
798+
)
799799
command_line += ' --test_config.type_ustr=ustr_commandline'
800800
values = _parse_flags(command_line, config=copy.copy(all_types_config))
801801

0 commit comments

Comments
 (0)