Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions ml_collections/config_flags/config_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def parse(self, argument: str) -> config_dict.ConfigDict:
config_dict.ConfigDict: _ConfigDictParser(),
object: _LiteralParser(),
}
# To serialize a new type, add a serializer to this dict.
# By default flags.ArgumentSerializer will be used.
_FIELD_TYPE_TO_SERIALIZER = {
}


class UnsupportedOperationError(flags.Error):
Expand Down Expand Up @@ -865,6 +869,12 @@ def _parse(self, argument):
parse_fn=ft.partial(_MakeDefaultOrNone, field_type,
allow_none=is_optional, field_path=field_path))

serializer = flags.ArgumentSerializer()
if field_type in _FIELD_TYPE_TO_SERIALIZER:
serializer = _FIELD_TYPE_TO_SERIALIZER[field_type]
if field_type_origin in _FIELD_TYPE_TO_SERIALIZER:
serializer = _FIELD_TYPE_TO_SERIALIZER[field_type_origin]

if parser:
if not isinstance(parser, tuple_parser.TupleParser):
if isinstance(parser, (_LiteralParser, _ConfigDictParser)):
Expand All @@ -882,7 +892,7 @@ def _parse(self, argument):
config=config,
override_values=self._override_values,
parser=parser,
serializer=flags.ArgumentSerializer(),
serializer=serializer,
name=field_name,
default=default,
accept_new_attributes=self._accept_new_attributes,
Expand All @@ -904,7 +914,7 @@ def _parse(self, argument):
config=config,
override_values=self._override_values,
parser=parser,
serializer=flags.ArgumentSerializer(),
serializer=serializer,
name=field_name,
default=config_path.get_value(field_path, config),
help_string=field_help,
Expand Down Expand Up @@ -1126,23 +1136,32 @@ def _parse(self, arguments):


def register_flag_parser_for_type(
field_type: _T, parser: flags.ArgumentParser) -> _T:
field_type: _T,
parser: flags.ArgumentParser,
serializer: flags.ArgumentSerializer = flags.ArgumentSerializer(),
) -> _T:
"""Registers parser for a given type.

See documentation for `register_flag_parser` for usage example.

Args:
field_type: field type to register
parser: parser to use
serializer: serializer to use

Returns:
field_type unmodified.
"""
_FIELD_TYPE_TO_PARSER[field_type] = parser
_FIELD_TYPE_TO_SERIALIZER[field_type] = serializer
return field_type


def register_flag_parser(*, parser: flags.ArgumentParser) -> Callable[[_T], _T]:
def register_flag_parser(
*,
parser: flags.ArgumentParser,
serializer: flags.ArgumentSerializer = flags.ArgumentSerializer(),
) -> Callable[[_T], _T]:
"""Creates a decorator to register parser on types.

For example:
Expand Down Expand Up @@ -1174,8 +1193,11 @@ class MainConfig:

Args:
parser: parser to use.
serializer: serializer to use.

Returns:
Decorator to apply to types.
"""
return ft.partial(register_flag_parser_for_type, parser=parser)
return ft.partial(
register_flag_parser_for_type, parser=parser, serializer=serializer
)
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def testOverridesSerialize(self):
' --test_config.type_int=10'
' --test_config.type_str=str_commandline'
' --test_config.type_tuple="(\'tuple_str\', 10)"'
)
)
command_line += ' --test_config.type_ustr=ustr_commandline'
values = _parse_flags(command_line, config=copy.copy(all_types_config))

Expand Down
Loading