Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,9 @@ print(Settings().model_dump())

#### CLI Kebab Case for Arguments

Change whether CLI arguments should use kebab case by enabling `cli_kebab_case`.
Change whether CLI arguments should use kebab case by enabling `cli_kebab_case`. By default, `cli_kebab_case=True` will
ignore enum fields, and is equivalent to `cli_kebab_case='no_enums'`. To apply kebab case to everything, including
enums, use `cli_kebab_case='all'`.

```py
import sys
Expand Down
8 changes: 4 additions & 4 deletions pydantic_settings/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from argparse import Namespace
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any, ClassVar, TypeVar
from typing import Any, ClassVar, Literal, TypeVar

from pydantic import ConfigDict
from pydantic._internal._config import config_keys
Expand Down Expand Up @@ -62,7 +62,7 @@ class SettingsConfigDict(ConfigDict, total=False):
cli_flag_prefix_char: str
cli_implicit_flags: bool | None
cli_ignore_unknown_args: bool | None
cli_kebab_case: bool | None
cli_kebab_case: bool | Literal['all', 'no_enums'] | None
cli_shortcuts: Mapping[str, str | list[str]] | None
secrets_dir: PathType | None
json_file: PathType | None
Expand Down Expand Up @@ -185,7 +185,7 @@ def __init__(
_cli_flag_prefix_char: str | None = None,
_cli_implicit_flags: bool | None = None,
_cli_ignore_unknown_args: bool | None = None,
_cli_kebab_case: bool | None = None,
_cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
_secrets_dir: PathType | None = None,
**values: Any,
Expand Down Expand Up @@ -272,7 +272,7 @@ def _settings_build_values(
_cli_flag_prefix_char: str | None = None,
_cli_implicit_flags: bool | None = None,
_cli_ignore_unknown_args: bool | None = None,
_cli_kebab_case: bool | None = None,
_cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
_secrets_dir: PathType | None = None,
) -> dict[str, Any]:
Expand Down
74 changes: 52 additions & 22 deletions pydantic_settings/sources/providers/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Any,
Callable,
Generic,
Literal,
NoReturn,
Optional,
TypeVar,
Expand Down Expand Up @@ -94,7 +95,7 @@ class _CliArg(BaseModel):
arg_prefix: str
case_sensitive: bool
hide_none_type: bool
kebab_case: bool
kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]]
enable_decoding: Optional[bool]
env_prefix_len: int
args: list[str] = []
Expand Down Expand Up @@ -131,8 +132,20 @@ def __init__(
parser_map[self.field_info][index] = parser_map[alias_path_dest][index]

@classmethod
def get_kebab_case(cls, name: str, kebab_case: Optional[bool]) -> str:
return name.replace('_', '-') if kebab_case else name
def get_kebab_case(cls, name: str, kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]]) -> str:
return name.replace('_', '-') if kebab_case not in (None, False) else name

@classmethod
def get_enum_names(
cls, annotation: type[Any], kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]]
) -> tuple[str, ...]:
enum_names: tuple[str, ...] = ()
annotation = _strip_annotated(annotation)
for type_ in get_args(annotation):
enum_names += cls.get_enum_names(type_, kebab_case)
if annotation and _lenient_issubclass(annotation, Enum):
enum_names += tuple(cls.get_kebab_case(val.name, kebab_case == 'all') for val in annotation)
return enum_names

def subcommand_alias(self, sub_model: type[BaseModel]) -> str:
return self.get_kebab_case(
Expand Down Expand Up @@ -294,7 +307,7 @@ def __init__(
cli_flag_prefix_char: str | None = None,
cli_implicit_flags: bool | None = None,
cli_ignore_unknown_args: bool | None = None,
cli_kebab_case: bool | None = None,
cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
cli_shortcuts: Mapping[str, str | list[str]] | None = None,
case_sensitive: bool | None = True,
root_parser: Any = None,
Expand Down Expand Up @@ -490,23 +503,7 @@ def _load_env_vars(
if isinstance(parsed_args, (Namespace, SimpleNamespace)):
parsed_args = vars(parsed_args)

selected_subcommands: list[str] = []
for field_name, val in list(parsed_args.items()):
if isinstance(val, list):
if self._is_nested_alias_path_only_workaround(parsed_args, field_name, val):
# Workaround for nested alias path environment variables not being handled.
# See https://github.com/pydantic/pydantic-settings/issues/670
continue

cli_arg = self._parser_map.get(field_name, {}).get(None)
if cli_arg and cli_arg.is_no_decode:
parsed_args[field_name] = ','.join(val)
continue

parsed_args[field_name] = self._merge_parsed_list(val, field_name)
elif field_name.endswith(':subcommand') and val is not None:
selected_subcommands.append(self._parser_map[field_name][val].dest)

selected_subcommands = self._resolve_parsed_args(parsed_args)
for arg_dest, arg_map in self._parser_map.items():
if isinstance(arg_dest, str) and arg_dest.endswith(':subcommand'):
for subcommand_dest in [arg.dest for arg in arg_map.values()]:
Expand Down Expand Up @@ -534,6 +531,37 @@ def _load_env_vars(

return self

def _resolve_parsed_args(self, parsed_args: dict[str, list[str] | str]) -> list[str]:
selected_subcommands: list[str] = []
for field_name, val in list(parsed_args.items()):
if isinstance(val, list):
if self._is_nested_alias_path_only_workaround(parsed_args, field_name, val):
# Workaround for nested alias path environment variables not being handled.
# See https://github.com/pydantic/pydantic-settings/issues/670
continue

cli_arg = self._parser_map.get(field_name, {}).get(None)
if cli_arg and cli_arg.is_no_decode:
parsed_args[field_name] = ','.join(val)
continue

parsed_args[field_name] = self._merge_parsed_list(val, field_name)
elif field_name.endswith(':subcommand') and val is not None:
selected_subcommands.append(self._parser_map[field_name][val].dest)
elif self.cli_kebab_case == 'all':
snake_val = val.replace('-', '_')
cli_arg = self._parser_map.get(field_name, {}).get(None)
if (
cli_arg
and cli_arg.field_info.annotation
and (snake_val in cli_arg.get_enum_names(cli_arg.field_info.annotation, False))
):
if '_' in val:
raise ValueError(f'Input should be kebab-case "{val.replace("_", "-")}", not "{val}"')
parsed_args[field_name] = snake_val

return selected_subcommands

def _is_nested_alias_path_only_workaround(
self, parsed_args: dict[str, list[str] | str], field_name: str, val: list[str]
) -> bool:
Expand Down Expand Up @@ -1198,7 +1226,9 @@ def _metavar_format_recurse(self, obj: Any) -> str:
elif typing_objects.is_literal(origin):
return self._metavar_format_choices(list(map(str, self._get_modified_args(obj))))
elif _lenient_issubclass(obj, Enum):
return self._metavar_format_choices([val.name for val in obj])
return self._metavar_format_choices(
[_CliArg.get_kebab_case(val.name, self.cli_kebab_case == 'all') for val in obj]
)
elif isinstance(obj, _WithArgsTypes):
return self._metavar_format_choices(
list(map(self._metavar_format_recurse, self._get_modified_args(obj))),
Expand Down
33 changes: 33 additions & 0 deletions tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2589,6 +2589,39 @@ class Root(BaseModel):
)


def test_cli_kebab_case_enums():
class Example1(IntEnum):
example_a = 0
example_b = 1

class Example2(IntEnum):
example_c = 2
example_d = 3

class SettingsNoEnum(BaseSettings):
model_config = SettingsConfigDict(cli_kebab_case='no_enums')
example: Union[Example1, Example2]
mybool: bool

class SettingsAll(BaseSettings):
model_config = SettingsConfigDict(cli_kebab_case='all')
example: Union[Example1, Example2]
mybool: bool

assert CliApp.run(
SettingsNoEnum,
cli_args=['--example', 'example_a', '--mybool=true'],
).model_dump() == {'example': Example1.example_a, 'mybool': True}

assert CliApp.run(SettingsAll, cli_args=['--example', 'example-c', '--mybool=true']).model_dump() == {
'example': Example2.example_c,
'mybool': True,
}

with pytest.raises(ValueError, match='Input should be kebab-case "example-a", not "example_a"'):
CliApp.run(SettingsAll, cli_args=['--example', 'example_a', '--mybool=true'])


def test_cli_with_unbalanced_brackets_in_json_string():
class StrToStrDictOptions(BaseSettings):
nested: dict[str, str]
Expand Down
Loading