Skip to content

Commit 68563ed

Browse files
authored
Support for enum kebab case. (#686)
1 parent 3e66430 commit 68563ed

File tree

4 files changed

+92
-27
lines changed

4 files changed

+92
-27
lines changed

docs/index.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1354,7 +1354,9 @@ print(Settings().model_dump())
13541354

13551355
#### CLI Kebab Case for Arguments
13561356

1357-
Change whether CLI arguments should use kebab case by enabling `cli_kebab_case`.
1357+
Change whether CLI arguments should use kebab case by enabling `cli_kebab_case`. By default, `cli_kebab_case=True` will
1358+
ignore enum fields, and is equivalent to `cli_kebab_case='no_enums'`. To apply kebab case to everything, including
1359+
enums, use `cli_kebab_case='all'`.
13581360

13591361
```py
13601362
import sys

pydantic_settings/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from argparse import Namespace
88
from collections.abc import Mapping
99
from types import SimpleNamespace
10-
from typing import Any, ClassVar, TypeVar
10+
from typing import Any, ClassVar, Literal, TypeVar
1111

1212
from pydantic import ConfigDict
1313
from pydantic._internal._config import config_keys
@@ -62,7 +62,7 @@ class SettingsConfigDict(ConfigDict, total=False):
6262
cli_flag_prefix_char: str
6363
cli_implicit_flags: bool | None
6464
cli_ignore_unknown_args: bool | None
65-
cli_kebab_case: bool | None
65+
cli_kebab_case: bool | Literal['all', 'no_enums'] | None
6666
cli_shortcuts: Mapping[str, str | list[str]] | None
6767
secrets_dir: PathType | None
6868
json_file: PathType | None
@@ -185,7 +185,7 @@ def __init__(
185185
_cli_flag_prefix_char: str | None = None,
186186
_cli_implicit_flags: bool | None = None,
187187
_cli_ignore_unknown_args: bool | None = None,
188-
_cli_kebab_case: bool | None = None,
188+
_cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
189189
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
190190
_secrets_dir: PathType | None = None,
191191
**values: Any,
@@ -272,7 +272,7 @@ def _settings_build_values(
272272
_cli_flag_prefix_char: str | None = None,
273273
_cli_implicit_flags: bool | None = None,
274274
_cli_ignore_unknown_args: bool | None = None,
275-
_cli_kebab_case: bool | None = None,
275+
_cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
276276
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
277277
_secrets_dir: PathType | None = None,
278278
) -> dict[str, Any]:

pydantic_settings/sources/providers/cli.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Any,
2828
Callable,
2929
Generic,
30+
Literal,
3031
NoReturn,
3132
Optional,
3233
TypeVar,
@@ -94,7 +95,7 @@ class _CliArg(BaseModel):
9495
arg_prefix: str
9596
case_sensitive: bool
9697
hide_none_type: bool
97-
kebab_case: bool
98+
kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]]
9899
enable_decoding: Optional[bool]
99100
env_prefix_len: int
100101
args: list[str] = []
@@ -131,8 +132,20 @@ def __init__(
131132
parser_map[self.field_info][index] = parser_map[alias_path_dest][index]
132133

133134
@classmethod
134-
def get_kebab_case(cls, name: str, kebab_case: Optional[bool]) -> str:
135-
return name.replace('_', '-') if kebab_case else name
135+
def get_kebab_case(cls, name: str, kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]]) -> str:
136+
return name.replace('_', '-') if kebab_case not in (None, False) else name
137+
138+
@classmethod
139+
def get_enum_names(
140+
cls, annotation: type[Any], kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]]
141+
) -> tuple[str, ...]:
142+
enum_names: tuple[str, ...] = ()
143+
annotation = _strip_annotated(annotation)
144+
for type_ in get_args(annotation):
145+
enum_names += cls.get_enum_names(type_, kebab_case)
146+
if annotation and _lenient_issubclass(annotation, Enum):
147+
enum_names += tuple(cls.get_kebab_case(val.name, kebab_case == 'all') for val in annotation)
148+
return enum_names
136149

137150
def subcommand_alias(self, sub_model: type[BaseModel]) -> str:
138151
return self.get_kebab_case(
@@ -294,7 +307,7 @@ def __init__(
294307
cli_flag_prefix_char: str | None = None,
295308
cli_implicit_flags: bool | None = None,
296309
cli_ignore_unknown_args: bool | None = None,
297-
cli_kebab_case: bool | None = None,
310+
cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
298311
cli_shortcuts: Mapping[str, str | list[str]] | None = None,
299312
case_sensitive: bool | None = True,
300313
root_parser: Any = None,
@@ -490,23 +503,7 @@ def _load_env_vars(
490503
if isinstance(parsed_args, (Namespace, SimpleNamespace)):
491504
parsed_args = vars(parsed_args)
492505

493-
selected_subcommands: list[str] = []
494-
for field_name, val in list(parsed_args.items()):
495-
if isinstance(val, list):
496-
if self._is_nested_alias_path_only_workaround(parsed_args, field_name, val):
497-
# Workaround for nested alias path environment variables not being handled.
498-
# See https://github.com/pydantic/pydantic-settings/issues/670
499-
continue
500-
501-
cli_arg = self._parser_map.get(field_name, {}).get(None)
502-
if cli_arg and cli_arg.is_no_decode:
503-
parsed_args[field_name] = ','.join(val)
504-
continue
505-
506-
parsed_args[field_name] = self._merge_parsed_list(val, field_name)
507-
elif field_name.endswith(':subcommand') and val is not None:
508-
selected_subcommands.append(self._parser_map[field_name][val].dest)
509-
506+
selected_subcommands = self._resolve_parsed_args(parsed_args)
510507
for arg_dest, arg_map in self._parser_map.items():
511508
if isinstance(arg_dest, str) and arg_dest.endswith(':subcommand'):
512509
for subcommand_dest in [arg.dest for arg in arg_map.values()]:
@@ -534,6 +531,37 @@ def _load_env_vars(
534531

535532
return self
536533

534+
def _resolve_parsed_args(self, parsed_args: dict[str, list[str] | str]) -> list[str]:
535+
selected_subcommands: list[str] = []
536+
for field_name, val in list(parsed_args.items()):
537+
if isinstance(val, list):
538+
if self._is_nested_alias_path_only_workaround(parsed_args, field_name, val):
539+
# Workaround for nested alias path environment variables not being handled.
540+
# See https://github.com/pydantic/pydantic-settings/issues/670
541+
continue
542+
543+
cli_arg = self._parser_map.get(field_name, {}).get(None)
544+
if cli_arg and cli_arg.is_no_decode:
545+
parsed_args[field_name] = ','.join(val)
546+
continue
547+
548+
parsed_args[field_name] = self._merge_parsed_list(val, field_name)
549+
elif field_name.endswith(':subcommand') and val is not None:
550+
selected_subcommands.append(self._parser_map[field_name][val].dest)
551+
elif self.cli_kebab_case == 'all':
552+
snake_val = val.replace('-', '_')
553+
cli_arg = self._parser_map.get(field_name, {}).get(None)
554+
if (
555+
cli_arg
556+
and cli_arg.field_info.annotation
557+
and (snake_val in cli_arg.get_enum_names(cli_arg.field_info.annotation, False))
558+
):
559+
if '_' in val:
560+
raise ValueError(f'Input should be kebab-case "{val.replace("_", "-")}", not "{val}"')
561+
parsed_args[field_name] = snake_val
562+
563+
return selected_subcommands
564+
537565
def _is_nested_alias_path_only_workaround(
538566
self, parsed_args: dict[str, list[str] | str], field_name: str, val: list[str]
539567
) -> bool:
@@ -1198,7 +1226,9 @@ def _metavar_format_recurse(self, obj: Any) -> str:
11981226
elif typing_objects.is_literal(origin):
11991227
return self._metavar_format_choices(list(map(str, self._get_modified_args(obj))))
12001228
elif _lenient_issubclass(obj, Enum):
1201-
return self._metavar_format_choices([val.name for val in obj])
1229+
return self._metavar_format_choices(
1230+
[_CliArg.get_kebab_case(val.name, self.cli_kebab_case == 'all') for val in obj]
1231+
)
12021232
elif isinstance(obj, _WithArgsTypes):
12031233
return self._metavar_format_choices(
12041234
list(map(self._metavar_format_recurse, self._get_modified_args(obj))),

tests/test_source_cli.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2589,6 +2589,39 @@ class Root(BaseModel):
25892589
)
25902590

25912591

2592+
def test_cli_kebab_case_enums():
2593+
class Example1(IntEnum):
2594+
example_a = 0
2595+
example_b = 1
2596+
2597+
class Example2(IntEnum):
2598+
example_c = 2
2599+
example_d = 3
2600+
2601+
class SettingsNoEnum(BaseSettings):
2602+
model_config = SettingsConfigDict(cli_kebab_case='no_enums')
2603+
example: Union[Example1, Example2]
2604+
mybool: bool
2605+
2606+
class SettingsAll(BaseSettings):
2607+
model_config = SettingsConfigDict(cli_kebab_case='all')
2608+
example: Union[Example1, Example2]
2609+
mybool: bool
2610+
2611+
assert CliApp.run(
2612+
SettingsNoEnum,
2613+
cli_args=['--example', 'example_a', '--mybool=true'],
2614+
).model_dump() == {'example': Example1.example_a, 'mybool': True}
2615+
2616+
assert CliApp.run(SettingsAll, cli_args=['--example', 'example-c', '--mybool=true']).model_dump() == {
2617+
'example': Example2.example_c,
2618+
'mybool': True,
2619+
}
2620+
2621+
with pytest.raises(ValueError, match='Input should be kebab-case "example-a", not "example_a"'):
2622+
CliApp.run(SettingsAll, cli_args=['--example', 'example_a', '--mybool=true'])
2623+
2624+
25922625
def test_cli_with_unbalanced_brackets_in_json_string():
25932626
class StrToStrDictOptions(BaseSettings):
25942627
nested: dict[str, str]

0 commit comments

Comments
 (0)