Skip to content

Commit ccb0d6f

Browse files
committed
Add support for CliMutuallyExclusiveGroup.
1 parent 87ad4db commit ccb0d6f

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

pydantic_settings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
AzureKeyVaultSettingsSource,
55
CliExplicitFlag,
66
CliImplicitFlag,
7+
CliMutuallyExclusiveGroup,
78
CliPositionalArg,
89
CliSettingsSource,
910
CliSubCommand,
@@ -34,6 +35,7 @@
3435
'CliPositionalArg',
3536
'CliExplicitFlag',
3637
'CliImplicitFlag',
38+
'CliMutuallyExclusiveGroup',
3739
'InitSettingsSource',
3840
'JsonConfigSettingsSource',
3941
'PyprojectTomlConfigSettingsSource',

pydantic_settings/sources.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ def error(self, message: str) -> NoReturn:
149149
super().error(message)
150150

151151

152+
class CliMutuallyExclusiveGroup(BaseModel):
153+
pass
154+
155+
152156
T = TypeVar('T')
153157
CliSubCommand = Annotated[Union[T, None], _CliSubCommand]
154158
CliPositionalArg = Annotated[T, _CliPositionalArg]
@@ -1515,6 +1519,25 @@ def none_parser_method(*args: Any, **kwargs: Any) -> Any:
15151519
else:
15161520
return parser_method
15171521

1522+
def _connect_group_method(self, add_argument_group_method: Callable[..., Any] | None) -> Callable[..., Any]:
1523+
add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')
1524+
1525+
def add_group_method(parser: Any, model: type[BaseModel], **kwargs: Any) -> Any:
1526+
if not issubclass(model, CliMutuallyExclusiveGroup):
1527+
kwargs.pop('required')
1528+
return add_argument_group(parser, **kwargs)
1529+
else:
1530+
group = add_argument_group(
1531+
parser, **{arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs}
1532+
)
1533+
if not hasattr(group, 'add_mutually_exclusive_group'):
1534+
raise SettingsError(
1535+
'cannot connect CLI settings source root parser: add_mutually_exclusive_group is set to `None` but is needed for connecting'
1536+
)
1537+
return group.add_mutually_exclusive_group(**kwargs)
1538+
1539+
return add_group_method
1540+
15181541
def _connect_root_parser(
15191542
self,
15201543
root_parser: T,
@@ -1533,7 +1556,7 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
15331556
parse_args_method = _parse_known_args if self.cli_ignore_unknown_args else ArgumentParser.parse_args
15341557
self._parse_args = self._connect_parser_method(parse_args_method, 'parsed_args_method')
15351558
self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method')
1536-
self._add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')
1559+
self._add_group = self._connect_group_method(add_argument_group_method)
15371560
self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method')
15381561
self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method')
15391562
self._formatter_class = formatter_class
@@ -1656,6 +1679,7 @@ def _add_parser_args(
16561679
if is_parser_submodel:
16571680
self._add_parser_submodels(
16581681
parser,
1682+
model,
16591683
sub_models,
16601684
added_args,
16611685
arg_prefix,
@@ -1671,7 +1695,7 @@ def _add_parser_args(
16711695
elif not is_alias_path_only:
16721696
if group is not None:
16731697
if isinstance(group, dict):
1674-
group = self._add_argument_group(parser, **group)
1698+
group = self._add_group(parser, model, **group)
16751699
added_args += list(arg_names)
16761700
self._add_argument(group, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs)
16771701
else:
@@ -1680,7 +1704,7 @@ def _add_parser_args(
16801704
parser, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs
16811705
)
16821706

1683-
self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
1707+
self._add_parser_alias_paths(parser, model, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
16841708
return parser
16851709

16861710
def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, model_default: Any) -> None:
@@ -1715,6 +1739,7 @@ def _get_arg_names(
17151739
def _add_parser_submodels(
17161740
self,
17171741
parser: Any,
1742+
model: type[BaseModel],
17181743
sub_models: list[type[BaseModel]],
17191744
added_args: list[str],
17201745
arg_prefix: str,
@@ -1727,10 +1752,18 @@ def _add_parser_submodels(
17271752
alias_names: tuple[str, ...],
17281753
model_default: Any,
17291754
) -> None:
1755+
if issubclass(model, CliMutuallyExclusiveGroup):
1756+
# Argparse has deprecated "calling add_argument_group() or add_mutually_exclusive_group() on a
1757+
# mutually exclusive group" (https://docs.python.org/3/library/argparse.html#mutual-exclusion).
1758+
# Since nested models result in a group add, raise an exception for nested models in a mutually
1759+
# exclusive group.
1760+
raise SettingsError('cannot have nested models in a CliMutuallyExclusiveGroup')
1761+
17301762
model_group: Any = None
17311763
model_group_kwargs: dict[str, Any] = {}
17321764
model_group_kwargs['title'] = f'{arg_names[0]} options'
17331765
model_group_kwargs['description'] = field_info.description
1766+
model_group_kwargs['required'] = kwargs['required']
17341767
if self.cli_use_class_docs_for_groups and len(sub_models) == 1:
17351768
model_group_kwargs['description'] = None if sub_models[0].__doc__ is None else dedent(sub_models[0].__doc__)
17361769

@@ -1753,7 +1786,7 @@ def _add_parser_submodels(
17531786
if not self.cli_avoid_json:
17541787
added_args.append(arg_names[0])
17551788
kwargs['help'] = f'set {arg_names[0]} from JSON string'
1756-
model_group = self._add_argument_group(parser, **model_group_kwargs)
1789+
model_group = self._add_group(parser, model, **model_group_kwargs)
17571790
self._add_argument(model_group, *(f'{flag_prefix}{name}' for name in arg_names), **kwargs)
17581791
for model in sub_models:
17591792
self._add_parser_args(
@@ -1770,6 +1803,7 @@ def _add_parser_submodels(
17701803
def _add_parser_alias_paths(
17711804
self,
17721805
parser: Any,
1806+
model: type[BaseModel],
17731807
alias_path_args: dict[str, str],
17741808
added_args: list[str],
17751809
arg_prefix: str,
@@ -1779,7 +1813,7 @@ def _add_parser_alias_paths(
17791813
if alias_path_args:
17801814
context = parser
17811815
if group is not None:
1782-
context = self._add_argument_group(parser, **group) if isinstance(group, dict) else group
1816+
context = self._add_group(parser, model, **group) if isinstance(group, dict) else group
17831817
is_nested_alias_path = arg_prefix.endswith('.')
17841818
arg_prefix = arg_prefix[:-1] if is_nested_alias_path else arg_prefix
17851819
for name, metavar in alias_path_args.items():

0 commit comments

Comments
 (0)