|
27 | 27 | Any, |
28 | 28 | Callable, |
29 | 29 | Generic, |
| 30 | + Literal, |
30 | 31 | NoReturn, |
31 | 32 | Optional, |
32 | 33 | TypeVar, |
@@ -94,7 +95,7 @@ class _CliArg(BaseModel): |
94 | 95 | arg_prefix: str |
95 | 96 | case_sensitive: bool |
96 | 97 | hide_none_type: bool |
97 | | - kebab_case: bool |
| 98 | + kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]] |
98 | 99 | enable_decoding: Optional[bool] |
99 | 100 | env_prefix_len: int |
100 | 101 | args: list[str] = [] |
@@ -131,8 +132,20 @@ def __init__( |
131 | 132 | parser_map[self.field_info][index] = parser_map[alias_path_dest][index] |
132 | 133 |
|
133 | 134 | @classmethod |
134 | | - def get_kebab_case(cls, name: str, kebab_case: Optional[bool]) -> str: |
135 | | - return name.replace('_', '-') if kebab_case else name |
| 135 | + def get_kebab_case(cls, name: str, kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]]) -> str: |
| 136 | + return name.replace('_', '-') if kebab_case not in (None, False) else name |
| 137 | + |
| 138 | + @classmethod |
| 139 | + def get_enum_names( |
| 140 | + cls, annotation: type[Any], kebab_case: Optional[Union[bool, Literal['all', 'no_enums']]] |
| 141 | + ) -> tuple[str, ...]: |
| 142 | + enum_names: tuple[str, ...] = () |
| 143 | + annotation = _strip_annotated(annotation) |
| 144 | + for type_ in get_args(annotation): |
| 145 | + enum_names += cls.get_enum_names(type_, kebab_case) |
| 146 | + if annotation and _lenient_issubclass(annotation, Enum): |
| 147 | + enum_names += tuple(cls.get_kebab_case(val.name, kebab_case == 'all') for val in annotation) |
| 148 | + return enum_names |
136 | 149 |
|
137 | 150 | def subcommand_alias(self, sub_model: type[BaseModel]) -> str: |
138 | 151 | return self.get_kebab_case( |
@@ -294,7 +307,7 @@ def __init__( |
294 | 307 | cli_flag_prefix_char: str | None = None, |
295 | 308 | cli_implicit_flags: bool | None = None, |
296 | 309 | cli_ignore_unknown_args: bool | None = None, |
297 | | - cli_kebab_case: bool | None = None, |
| 310 | + cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None, |
298 | 311 | cli_shortcuts: Mapping[str, str | list[str]] | None = None, |
299 | 312 | case_sensitive: bool | None = True, |
300 | 313 | root_parser: Any = None, |
@@ -490,23 +503,7 @@ def _load_env_vars( |
490 | 503 | if isinstance(parsed_args, (Namespace, SimpleNamespace)): |
491 | 504 | parsed_args = vars(parsed_args) |
492 | 505 |
|
493 | | - selected_subcommands: list[str] = [] |
494 | | - for field_name, val in list(parsed_args.items()): |
495 | | - if isinstance(val, list): |
496 | | - if self._is_nested_alias_path_only_workaround(parsed_args, field_name, val): |
497 | | - # Workaround for nested alias path environment variables not being handled. |
498 | | - # See https://github.com/pydantic/pydantic-settings/issues/670 |
499 | | - continue |
500 | | - |
501 | | - cli_arg = self._parser_map.get(field_name, {}).get(None) |
502 | | - if cli_arg and cli_arg.is_no_decode: |
503 | | - parsed_args[field_name] = ','.join(val) |
504 | | - continue |
505 | | - |
506 | | - parsed_args[field_name] = self._merge_parsed_list(val, field_name) |
507 | | - elif field_name.endswith(':subcommand') and val is not None: |
508 | | - selected_subcommands.append(self._parser_map[field_name][val].dest) |
509 | | - |
| 506 | + selected_subcommands = self._resolve_parsed_args(parsed_args) |
510 | 507 | for arg_dest, arg_map in self._parser_map.items(): |
511 | 508 | if isinstance(arg_dest, str) and arg_dest.endswith(':subcommand'): |
512 | 509 | for subcommand_dest in [arg.dest for arg in arg_map.values()]: |
@@ -534,6 +531,37 @@ def _load_env_vars( |
534 | 531 |
|
535 | 532 | return self |
536 | 533 |
|
| 534 | + def _resolve_parsed_args(self, parsed_args: dict[str, list[str] | str]) -> list[str]: |
| 535 | + selected_subcommands: list[str] = [] |
| 536 | + for field_name, val in list(parsed_args.items()): |
| 537 | + if isinstance(val, list): |
| 538 | + if self._is_nested_alias_path_only_workaround(parsed_args, field_name, val): |
| 539 | + # Workaround for nested alias path environment variables not being handled. |
| 540 | + # See https://github.com/pydantic/pydantic-settings/issues/670 |
| 541 | + continue |
| 542 | + |
| 543 | + cli_arg = self._parser_map.get(field_name, {}).get(None) |
| 544 | + if cli_arg and cli_arg.is_no_decode: |
| 545 | + parsed_args[field_name] = ','.join(val) |
| 546 | + continue |
| 547 | + |
| 548 | + parsed_args[field_name] = self._merge_parsed_list(val, field_name) |
| 549 | + elif field_name.endswith(':subcommand') and val is not None: |
| 550 | + selected_subcommands.append(self._parser_map[field_name][val].dest) |
| 551 | + elif self.cli_kebab_case == 'all': |
| 552 | + snake_val = val.replace('-', '_') |
| 553 | + cli_arg = self._parser_map.get(field_name, {}).get(None) |
| 554 | + if ( |
| 555 | + cli_arg |
| 556 | + and cli_arg.field_info.annotation |
| 557 | + and (snake_val in cli_arg.get_enum_names(cli_arg.field_info.annotation, False)) |
| 558 | + ): |
| 559 | + if '_' in val: |
| 560 | + raise ValueError(f'Input should be kebab-case "{val.replace("_", "-")}", not "{val}"') |
| 561 | + parsed_args[field_name] = snake_val |
| 562 | + |
| 563 | + return selected_subcommands |
| 564 | + |
537 | 565 | def _is_nested_alias_path_only_workaround( |
538 | 566 | self, parsed_args: dict[str, list[str] | str], field_name: str, val: list[str] |
539 | 567 | ) -> bool: |
@@ -1198,7 +1226,9 @@ def _metavar_format_recurse(self, obj: Any) -> str: |
1198 | 1226 | elif typing_objects.is_literal(origin): |
1199 | 1227 | return self._metavar_format_choices(list(map(str, self._get_modified_args(obj)))) |
1200 | 1228 | elif _lenient_issubclass(obj, Enum): |
1201 | | - return self._metavar_format_choices([val.name for val in obj]) |
| 1229 | + return self._metavar_format_choices( |
| 1230 | + [_CliArg.get_kebab_case(val.name, self.cli_kebab_case == 'all') for val in obj] |
| 1231 | + ) |
1202 | 1232 | elif isinstance(obj, _WithArgsTypes): |
1203 | 1233 | return self._metavar_format_choices( |
1204 | 1234 | list(map(self._metavar_format_recurse, self._get_modified_args(obj))), |
|
0 commit comments