diff --git a/ml_collections/config_flags/config_flags.py b/ml_collections/config_flags/config_flags.py index becd743..faf62a8 100644 --- a/ml_collections/config_flags/config_flags.py +++ b/ml_collections/config_flags/config_flags.py @@ -116,6 +116,10 @@ def parse(self, argument: str) -> config_dict.ConfigDict: config_dict.ConfigDict: _ConfigDictParser(), object: _LiteralParser(), } +# To serialize a new type, add a serializer to this dict. +# By default flags.ArgumentSerializer will be used. +_FIELD_TYPE_TO_SERIALIZER = { +} class UnsupportedOperationError(flags.Error): @@ -865,6 +869,12 @@ def _parse(self, argument): parse_fn=ft.partial(_MakeDefaultOrNone, field_type, allow_none=is_optional, field_path=field_path)) + serializer = flags.ArgumentSerializer() + if field_type in _FIELD_TYPE_TO_SERIALIZER: + serializer = _FIELD_TYPE_TO_SERIALIZER[field_type] + if field_type_origin in _FIELD_TYPE_TO_SERIALIZER: + serializer = _FIELD_TYPE_TO_SERIALIZER[field_type_origin] + if parser: if not isinstance(parser, tuple_parser.TupleParser): if isinstance(parser, (_LiteralParser, _ConfigDictParser)): @@ -882,7 +892,7 @@ def _parse(self, argument): config=config, override_values=self._override_values, parser=parser, - serializer=flags.ArgumentSerializer(), + serializer=serializer, name=field_name, default=default, accept_new_attributes=self._accept_new_attributes, @@ -904,7 +914,7 @@ def _parse(self, argument): config=config, override_values=self._override_values, parser=parser, - serializer=flags.ArgumentSerializer(), + serializer=serializer, name=field_name, default=config_path.get_value(field_path, config), help_string=field_help, @@ -1126,7 +1136,10 @@ def _parse(self, arguments): def register_flag_parser_for_type( - field_type: _T, parser: flags.ArgumentParser) -> _T: + field_type: _T, + parser: flags.ArgumentParser, + serializer: flags.ArgumentSerializer = flags.ArgumentSerializer(), +) -> _T: """Registers parser for a given type. See documentation for `register_flag_parser` for usage example. @@ -1134,15 +1147,21 @@ def register_flag_parser_for_type( Args: field_type: field type to register parser: parser to use + serializer: serializer to use Returns: field_type unmodified. """ _FIELD_TYPE_TO_PARSER[field_type] = parser + _FIELD_TYPE_TO_SERIALIZER[field_type] = serializer return field_type -def register_flag_parser(*, parser: flags.ArgumentParser) -> Callable[[_T], _T]: +def register_flag_parser( + *, + parser: flags.ArgumentParser, + serializer: flags.ArgumentSerializer = flags.ArgumentSerializer(), +) -> Callable[[_T], _T]: """Creates a decorator to register parser on types. For example: @@ -1174,8 +1193,11 @@ class MainConfig: Args: parser: parser to use. + serializer: serializer to use. Returns: Decorator to apply to types. """ - return ft.partial(register_flag_parser_for_type, parser=parser) + return ft.partial( + register_flag_parser_for_type, parser=parser, serializer=serializer + ) diff --git a/ml_collections/config_flags/tests/config_overriding_test.py b/ml_collections/config_flags/tests/config_overriding_test.py index ce5e696..2d5a62d 100644 --- a/ml_collections/config_flags/tests/config_overriding_test.py +++ b/ml_collections/config_flags/tests/config_overriding_test.py @@ -795,7 +795,7 @@ def testOverridesSerialize(self): ' --test_config.type_int=10' ' --test_config.type_str=str_commandline' ' --test_config.type_tuple="(\'tuple_str\', 10)"' - ) + ) command_line += ' --test_config.type_ustr=ustr_commandline' values = _parse_flags(command_line, config=copy.copy(all_types_config))