From c044099500f645de12bec0c51a4adb6731f674e3 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Sun, 28 Sep 2025 12:04:45 -0600 Subject: [PATCH] Support for enum kebab case. --- docs/index.md | 4 +- pydantic_settings/main.py | 8 +-- pydantic_settings/sources/providers/cli.py | 74 +++++++++++++++------- tests/test_source_cli.py | 33 ++++++++++ 4 files changed, 92 insertions(+), 27 deletions(-) diff --git a/docs/index.md b/docs/index.md index e5b39fa1..2867adaf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1354,7 +1354,9 @@ print(Settings().model_dump()) #### CLI Kebab Case for Arguments -Change whether CLI arguments should use kebab case by enabling `cli_kebab_case`. +Change whether CLI arguments should use kebab case by enabling `cli_kebab_case`. By default, `cli_kebab_case=True` will +ignore enum fields, and is equivalent to `cli_kebab_case='no_enums'`. To apply kebab case to everything, including +enums, use `cli_kebab_case='all'`. ```py import sys diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index f72e950f..e06f4831 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -7,7 +7,7 @@ from argparse import Namespace from collections.abc import Mapping from types import SimpleNamespace -from typing import Any, ClassVar, TypeVar +from typing import Any, ClassVar, Literal, TypeVar from pydantic import ConfigDict from pydantic._internal._config import config_keys @@ -62,7 +62,7 @@ class SettingsConfigDict(ConfigDict, total=False): cli_flag_prefix_char: str cli_implicit_flags: bool | None cli_ignore_unknown_args: bool | None - cli_kebab_case: bool | None + cli_kebab_case: bool | Literal['all', 'no_enums'] | None cli_shortcuts: Mapping[str, str | list[str]] | None secrets_dir: PathType | None json_file: PathType | None @@ -185,7 +185,7 @@ def __init__( _cli_flag_prefix_char: str | None = None, _cli_implicit_flags: bool | None = None, _cli_ignore_unknown_args: bool | None = None, - _cli_kebab_case: bool | None = None, + _cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None, _cli_shortcuts: Mapping[str, str | list[str]] | None = None, _secrets_dir: PathType | None = None, **values: Any, @@ -272,7 +272,7 @@ def _settings_build_values( _cli_flag_prefix_char: str | None = None, _cli_implicit_flags: bool | None = None, _cli_ignore_unknown_args: bool | None = None, - _cli_kebab_case: bool | None = None, + _cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None, _cli_shortcuts: Mapping[str, str | list[str]] | None = None, _secrets_dir: PathType | None = None, ) -> dict[str, Any]: diff --git a/pydantic_settings/sources/providers/cli.py b/pydantic_settings/sources/providers/cli.py index 87d0d58a..5248d8be 100644 --- a/pydantic_settings/sources/providers/cli.py +++ b/pydantic_settings/sources/providers/cli.py @@ -27,6 +27,7 @@ Any, Callable, Generic, + Literal, NoReturn, Optional, TypeVar, @@ -94,7 +95,7 @@ class _CliArg(BaseModel): arg_prefix: str case_sensitive: bool hide_none_type: bool - kebab_case: bool + kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]] enable_decoding: Optional[bool] env_prefix_len: int args: list[str] = [] @@ -131,8 +132,20 @@ def __init__( parser_map[self.field_info][index] = parser_map[alias_path_dest][index] @classmethod - def get_kebab_case(cls, name: str, kebab_case: Optional[bool]) -> str: - return name.replace('_', '-') if kebab_case else name + def get_kebab_case(cls, name: str, kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]]) -> str: + return name.replace('_', '-') if kebab_case not in (None, False) else name + + @classmethod + def get_enum_names( + cls, annotation: type[Any], kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]] + ) -> tuple[str, ...]: + enum_names: tuple[str, ...] = () + annotation = _strip_annotated(annotation) + for type_ in get_args(annotation): + enum_names += cls.get_enum_names(type_, kebab_case) + if annotation and _lenient_issubclass(annotation, Enum): + enum_names += tuple(cls.get_kebab_case(val.name, kebab_case == 'all') for val in annotation) + return enum_names def subcommand_alias(self, sub_model: type[BaseModel]) -> str: return self.get_kebab_case( @@ -294,7 +307,7 @@ def __init__( cli_flag_prefix_char: str | None = None, cli_implicit_flags: bool | None = None, cli_ignore_unknown_args: bool | None = None, - cli_kebab_case: bool | None = None, + cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None, cli_shortcuts: Mapping[str, str | list[str]] | None = None, case_sensitive: bool | None = True, root_parser: Any = None, @@ -490,23 +503,7 @@ def _load_env_vars( if isinstance(parsed_args, (Namespace, SimpleNamespace)): parsed_args = vars(parsed_args) - selected_subcommands: list[str] = [] - for field_name, val in list(parsed_args.items()): - if isinstance(val, list): - if self._is_nested_alias_path_only_workaround(parsed_args, field_name, val): - # Workaround for nested alias path environment variables not being handled. - # See https://github.com/pydantic/pydantic-settings/issues/670 - continue - - cli_arg = self._parser_map.get(field_name, {}).get(None) - if cli_arg and cli_arg.is_no_decode: - parsed_args[field_name] = ','.join(val) - continue - - parsed_args[field_name] = self._merge_parsed_list(val, field_name) - elif field_name.endswith(':subcommand') and val is not None: - selected_subcommands.append(self._parser_map[field_name][val].dest) - + selected_subcommands = self._resolve_parsed_args(parsed_args) for arg_dest, arg_map in self._parser_map.items(): if isinstance(arg_dest, str) and arg_dest.endswith(':subcommand'): for subcommand_dest in [arg.dest for arg in arg_map.values()]: @@ -534,6 +531,37 @@ def _load_env_vars( return self + def _resolve_parsed_args(self, parsed_args: dict[str, list[str] | str]) -> list[str]: + selected_subcommands: list[str] = [] + for field_name, val in list(parsed_args.items()): + if isinstance(val, list): + if self._is_nested_alias_path_only_workaround(parsed_args, field_name, val): + # Workaround for nested alias path environment variables not being handled. + # See https://github.com/pydantic/pydantic-settings/issues/670 + continue + + cli_arg = self._parser_map.get(field_name, {}).get(None) + if cli_arg and cli_arg.is_no_decode: + parsed_args[field_name] = ','.join(val) + continue + + parsed_args[field_name] = self._merge_parsed_list(val, field_name) + elif field_name.endswith(':subcommand') and val is not None: + selected_subcommands.append(self._parser_map[field_name][val].dest) + elif self.cli_kebab_case == 'all': + snake_val = val.replace('-', '_') + cli_arg = self._parser_map.get(field_name, {}).get(None) + if ( + cli_arg + and cli_arg.field_info.annotation + and (snake_val in cli_arg.get_enum_names(cli_arg.field_info.annotation, False)) + ): + if '_' in val: + raise ValueError(f'Input should be kebab-case "{val.replace("_", "-")}", not "{val}"') + parsed_args[field_name] = snake_val + + return selected_subcommands + def _is_nested_alias_path_only_workaround( self, parsed_args: dict[str, list[str] | str], field_name: str, val: list[str] ) -> bool: @@ -1198,7 +1226,9 @@ def _metavar_format_recurse(self, obj: Any) -> str: elif typing_objects.is_literal(origin): return self._metavar_format_choices(list(map(str, self._get_modified_args(obj)))) elif _lenient_issubclass(obj, Enum): - return self._metavar_format_choices([val.name for val in obj]) + return self._metavar_format_choices( + [_CliArg.get_kebab_case(val.name, self.cli_kebab_case == 'all') for val in obj] + ) elif isinstance(obj, _WithArgsTypes): return self._metavar_format_choices( list(map(self._metavar_format_recurse, self._get_modified_args(obj))), diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index 1fe98f9c..9770c625 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -2589,6 +2589,39 @@ class Root(BaseModel): ) +def test_cli_kebab_case_enums(): + class Example1(IntEnum): + example_a = 0 + example_b = 1 + + class Example2(IntEnum): + example_c = 2 + example_d = 3 + + class SettingsNoEnum(BaseSettings): + model_config = SettingsConfigDict(cli_kebab_case='no_enums') + example: Union[Example1, Example2] + mybool: bool + + class SettingsAll(BaseSettings): + model_config = SettingsConfigDict(cli_kebab_case='all') + example: Union[Example1, Example2] + mybool: bool + + assert CliApp.run( + SettingsNoEnum, + cli_args=['--example', 'example_a', '--mybool=true'], + ).model_dump() == {'example': Example1.example_a, 'mybool': True} + + assert CliApp.run(SettingsAll, cli_args=['--example', 'example-c', '--mybool=true']).model_dump() == { + 'example': Example2.example_c, + 'mybool': True, + } + + with pytest.raises(ValueError, match='Input should be kebab-case "example-a", not "example_a"'): + CliApp.run(SettingsAll, cli_args=['--example', 'example_a', '--mybool=true']) + + def test_cli_with_unbalanced_brackets_in_json_string(): class StrToStrDictOptions(BaseSettings): nested: dict[str, str]