Skip to content

Commit a9eb22e

Browse files
kschwabKyle Schwab
andauthored
Add CLI subcommand union and alias support (#380)
Co-authored-by: Kyle Schwab <[email protected]>
1 parent 12d85cf commit a9eb22e

File tree

3 files changed

+413
-80
lines changed

3 files changed

+413
-80
lines changed

docs/index.md

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ not required, set the `is_required` flag to `False` to disable raising an error
757757
subcommands](https://docs.python.org/3/library/argparse.html#sub-commands).
758758

759759
!!! note
760-
`CliSubCommand` and `CliPositionalArg` are always case sensitive and do not support aliases.
760+
`CliSubCommand` and `CliPositionalArg` are always case sensitive.
761761

762762
```py
763763
import sys
@@ -817,6 +817,64 @@ assert get_subcommand(cmd).model_dump() == {
817817
}
818818
```
819819

820+
The `CliSubCommand` and `CliPositionalArg` annotations also support union operations and aliases. For unions of Pydantic
821+
models, it is important to remember the [nuances](https://docs.pydantic.dev/latest/concepts/unions/) that can arise
822+
during validation. Specifically, for unions of subcommands that are identical in content, it is recommended to break
823+
them out into separate `CliSubCommand` fields to avoid any complications. Lastly, the derived subcommand names from
824+
unions will be the names of the Pydantic model classes themselves.
825+
826+
When assigning aliases to `CliSubCommand` or `CliPositionalArg` fields, only a single alias can be assigned. For
827+
non-union subcommands, aliasing will change the displayed help text and subcommand name. Conversely, for union
828+
subcommands, aliasing will have no tangible effect from the perspective of the CLI settings source. Lastly, for
829+
positional arguments, aliasing will change the CLI help text displayed for the field.
830+
831+
```py
832+
import sys
833+
from typing import Union
834+
835+
from pydantic import BaseModel, Field
836+
837+
from pydantic_settings import (
838+
BaseSettings,
839+
CliPositionalArg,
840+
CliSubCommand,
841+
get_subcommand,
842+
)
843+
844+
845+
class Alpha(BaseModel):
846+
"""Apha Help"""
847+
848+
cmd_alpha: CliPositionalArg[str] = Field(alias='alpha-cmd')
849+
850+
851+
class Beta(BaseModel):
852+
"""Beta Help"""
853+
854+
opt_beta: str = Field(alias='opt-beta')
855+
856+
857+
class Gamma(BaseModel):
858+
"""Gamma Help"""
859+
860+
opt_gamma: str = Field(alias='opt-gamma')
861+
862+
863+
class Root(BaseSettings, cli_parse_args=True, cli_exit_on_error=False):
864+
alpha_or_beta: CliSubCommand[Union[Alpha, Beta]] = Field(alias='alpha-or-beta-cmd')
865+
gamma: CliSubCommand[Gamma] = Field(alias='gamma-cmd')
866+
867+
868+
sys.argv = ['example.py', 'Alpha', 'hello']
869+
assert get_subcommand(Root()).model_dump() == {'cmd_alpha': 'hello'}
870+
871+
sys.argv = ['example.py', 'Beta', '--opt-beta=hey']
872+
assert get_subcommand(Root()).model_dump() == {'opt_beta': 'hey'}
873+
874+
sys.argv = ['example.py', 'gamma-cmd', '--opt-gamma=hi']
875+
assert get_subcommand(Root()).model_dump() == {'opt_gamma': 'hi'}
876+
```
877+
820878
### Customizing the CLI Experience
821879

822880
The below flags can be used to customise the CLI experience to your needs.
@@ -861,9 +919,11 @@ Additionally, the provided `CliImplicitFlag` and `CliExplicitFlag` annotations c
861919
when necessary.
862920

863921
!!! note
864-
For `python < 3.9`:
865-
* The `--no-flag` option is not generated due to an underlying `argparse` limitation.
866-
* The `CliImplicitFlag` and `CliExplicitFlag` annotations can only be applied to optional bool fields.
922+
For `python < 3.9` the `--no-flag` option is not generated due to an underlying `argparse` limitation.
923+
924+
!!! note
925+
For `python < 3.9` the `CliImplicitFlag` and `CliExplicitFlag` annotations can only be applied to optional boolean
926+
fields.
867927

868928
```py
869929
from pydantic_settings import BaseSettings, CliExplicitFlag, CliImplicitFlag

pydantic_settings/sources.py

Lines changed: 99 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
if sys.version_info >= (3, 9):
1414
from argparse import BooleanOptionalAction
1515
from argparse import SUPPRESS, ArgumentParser, Namespace, RawDescriptionHelpFormatter, _SubParsersAction
16-
from collections import deque
16+
from collections import defaultdict, deque
1717
from dataclasses import asdict, is_dataclass
1818
from enum import Enum
1919
from pathlib import Path
@@ -1239,12 +1239,14 @@ def _load_env_vars(
12391239
if isinstance(val, list):
12401240
parsed_args[field_name] = self._merge_parsed_list(val, field_name)
12411241
elif field_name.endswith(':subcommand') and val is not None:
1242-
selected_subcommands.append(field_name.split(':')[0] + val)
1242+
subcommand_name = field_name.split(':')[0] + val
1243+
subcommand_dest = self._cli_subcommands[field_name][subcommand_name]
1244+
selected_subcommands.append(subcommand_dest)
12431245

12441246
for subcommands in self._cli_subcommands.values():
1245-
for subcommand in subcommands:
1246-
if subcommand not in selected_subcommands:
1247-
parsed_args[subcommand] = self.cli_parse_none_str
1247+
for subcommand_dest in subcommands.values():
1248+
if subcommand_dest not in selected_subcommands:
1249+
parsed_args[subcommand_dest] = self.cli_parse_none_str
12481250

12491251
parsed_args = {key: val for key, val in parsed_args.items() if not key.endswith(':subcommand')}
12501252
if selected_subcommands:
@@ -1389,26 +1391,26 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F
13891391
sub_models.append(type_) # type: ignore
13901392
return sub_models
13911393

1392-
def _get_resolved_names(
1394+
def _get_alias_names(
13931395
self, field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str]
13941396
) -> tuple[tuple[str, ...], bool]:
1395-
resolved_names: list[str] = []
1397+
alias_names: list[str] = []
13961398
is_alias_path_only: bool = True
13971399
if not any((field_info.alias, field_info.validation_alias)):
1398-
resolved_names += [field_name]
1400+
alias_names += [field_name]
13991401
is_alias_path_only = False
14001402
else:
14011403
new_alias_paths: list[AliasPath] = []
14021404
for alias in (field_info.alias, field_info.validation_alias):
14031405
if alias is None:
14041406
continue
14051407
elif isinstance(alias, str):
1406-
resolved_names.append(alias)
1408+
alias_names.append(alias)
14071409
is_alias_path_only = False
14081410
elif isinstance(alias, AliasChoices):
14091411
for name in alias.choices:
14101412
if isinstance(name, str):
1411-
resolved_names.append(name)
1413+
alias_names.append(name)
14121414
is_alias_path_only = False
14131415
else:
14141416
new_alias_paths.append(name)
@@ -1418,11 +1420,11 @@ def _get_resolved_names(
14181420
name = cast(str, alias_path.path[0])
14191421
name = name.lower() if not self.case_sensitive else name
14201422
alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list'
1421-
if not resolved_names and is_alias_path_only:
1422-
resolved_names.append(name)
1423+
if not alias_names and is_alias_path_only:
1424+
alias_names.append(name)
14231425
if not self.case_sensitive:
1424-
resolved_names = [resolved_name.lower() for resolved_name in resolved_names]
1425-
return tuple(dict.fromkeys(resolved_names)), is_alias_path_only
1426+
alias_names = [alias_name.lower() for alias_name in alias_names]
1427+
return tuple(dict.fromkeys(alias_names)), is_alias_path_only
14261428

14271429
def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:
14281430
if _CliImplicitFlag in field_info.metadata:
@@ -1447,22 +1449,24 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
14471449
if _CliSubCommand in field_info.metadata:
14481450
if not field_info.is_required():
14491451
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value')
1450-
elif any((field_info.alias, field_info.validation_alias)):
1451-
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has an alias')
14521452
else:
1453+
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
1454+
if len(alias_names) > 1:
1455+
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases')
14531456
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
1454-
if len(field_types) != 1:
1455-
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple types')
1456-
elif not (is_model_class(field_types[0]) or is_pydantic_dataclass(field_types[0])):
1457-
raise SettingsError(
1458-
f'subcommand argument {model.__name__}.{field_name} is not derived from BaseModel'
1459-
)
1457+
for field_type in field_types:
1458+
if not (is_model_class(field_type) or is_pydantic_dataclass(field_type)):
1459+
raise SettingsError(
1460+
f'subcommand argument {model.__name__}.{field_name} has type not derived from BaseModel'
1461+
)
14601462
subcommand_args.append((field_name, field_info))
14611463
elif _CliPositionalArg in field_info.metadata:
14621464
if not field_info.is_required():
14631465
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
1464-
elif any((field_info.alias, field_info.validation_alias)):
1465-
raise SettingsError(f'positional argument {model.__name__}.{field_name} has an alias')
1466+
else:
1467+
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
1468+
if len(alias_names) > 1:
1469+
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
14661470
positional_args.append((field_name, field_info))
14671471
else:
14681472
self._verify_cli_flag_annotations(model, field_name, field_info)
@@ -1529,7 +1533,7 @@ def _connect_root_parser(
15291533
self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method')
15301534
self._formatter_class = formatter_class
15311535
self._cli_dict_args: dict[str, type[Any] | None] = {}
1532-
self._cli_subcommands: dict[str, list[str]] = {}
1536+
self._cli_subcommands: defaultdict[str, dict[str, str]] = defaultdict(dict)
15331537
self._add_parser_args(
15341538
parser=self.root_parser,
15351539
model=self.settings_cls,
@@ -1556,64 +1560,93 @@ def _add_parser_args(
15561560
alias_path_args: dict[str, str] = {}
15571561
for field_name, field_info in self._sort_arg_fields(model):
15581562
sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info)
1563+
alias_names, is_alias_path_only = self._get_alias_names(field_name, field_info, alias_path_args)
1564+
preferred_alias = alias_names[0]
15591565
if _CliSubCommand in field_info.metadata:
1560-
if subparsers is None:
1561-
subparsers = self._add_subparsers(parser, title='subcommands', dest=f'{arg_prefix}:subcommand')
1562-
self._cli_subcommands[f'{arg_prefix}:subcommand'] = [f'{arg_prefix}{field_name}']
1563-
else:
1564-
self._cli_subcommands[f'{arg_prefix}:subcommand'].append(f'{arg_prefix}{field_name}')
1565-
if hasattr(subparsers, 'metavar'):
1566-
metavar = ','.join(self._cli_subcommands[f'{arg_prefix}:subcommand'])
1567-
subparsers.metavar = f'{{{metavar}}}'
1568-
1569-
model = sub_models[0]
1570-
self._add_parser_args(
1571-
parser=self._add_parser(
1572-
subparsers,
1573-
field_name,
1574-
help=field_info.description,
1575-
formatter_class=self._formatter_class,
1576-
description=None if model.__doc__ is None else dedent(model.__doc__),
1577-
),
1578-
model=model,
1579-
added_args=[],
1580-
arg_prefix=f'{arg_prefix}{field_name}.',
1581-
subcommand_prefix=f'{subcommand_prefix}{field_name}.',
1582-
group=None,
1583-
alias_prefixes=[],
1584-
model_default=PydanticUndefined,
1585-
)
1566+
for model in sub_models:
1567+
subcommand_alias = model.__name__ if len(sub_models) > 1 else preferred_alias
1568+
subcommand_name = f'{arg_prefix}{subcommand_alias}'
1569+
subcommand_dest = f'{arg_prefix}{preferred_alias}'
1570+
self._cli_subcommands[f'{arg_prefix}:subcommand'][subcommand_name] = subcommand_dest
1571+
1572+
subcommand_help = None if len(sub_models) > 1 else field_info.description
1573+
if self.cli_use_class_docs_for_groups:
1574+
subcommand_help = None if model.__doc__ is None else dedent(model.__doc__)
1575+
1576+
subparsers = (
1577+
self._add_subparsers(
1578+
parser,
1579+
title='subcommands',
1580+
dest=f'{arg_prefix}:subcommand',
1581+
description=field_info.description if len(sub_models) > 1 else None,
1582+
)
1583+
if subparsers is None
1584+
else subparsers
1585+
)
1586+
1587+
if hasattr(subparsers, 'metavar'):
1588+
subparsers.metavar = (
1589+
f'{subparsers.metavar[:-1]},{subcommand_alias}}}'
1590+
if subparsers.metavar
1591+
else f'{{{subcommand_alias}}}'
1592+
)
1593+
1594+
self._add_parser_args(
1595+
parser=self._add_parser(
1596+
subparsers,
1597+
subcommand_alias,
1598+
help=subcommand_help,
1599+
formatter_class=self._formatter_class,
1600+
description=None if model.__doc__ is None else dedent(model.__doc__),
1601+
),
1602+
model=model,
1603+
added_args=[],
1604+
arg_prefix=f'{arg_prefix}{preferred_alias}.',
1605+
subcommand_prefix=f'{subcommand_prefix}{preferred_alias}.',
1606+
group=None,
1607+
alias_prefixes=[],
1608+
model_default=PydanticUndefined,
1609+
)
15861610
else:
1587-
resolved_names, is_alias_path_only = self._get_resolved_names(field_name, field_info, alias_path_args)
15881611
arg_flag: str = '--'
1612+
is_append_action = _annotation_contains_types(
1613+
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
1614+
)
1615+
is_parser_submodel = sub_models and not is_append_action
15891616
kwargs: dict[str, Any] = {}
15901617
kwargs['default'] = SUPPRESS
15911618
kwargs['help'] = self._help_format(field_name, field_info, model_default)
1592-
kwargs['dest'] = f'{arg_prefix}{resolved_names[0]}'
15931619
kwargs['metavar'] = self._metavar_format(field_info.annotation)
15941620
kwargs['required'] = (
15951621
self.cli_enforce_required and field_info.is_required() and model_default is PydanticUndefined
15961622
)
1623+
kwargs['dest'] = (
1624+
# Strip prefix if validation alias is set and value is not complex.
1625+
# Related https://github.com/pydantic/pydantic-settings/pull/25
1626+
f'{arg_prefix}{preferred_alias}'[self.env_prefix_len :]
1627+
if arg_prefix and field_info.validation_alias is not None and not is_parser_submodel
1628+
else f'{arg_prefix}{preferred_alias}'
1629+
)
1630+
15971631
if kwargs['dest'] in added_args:
15981632
continue
1599-
if _annotation_contains_types(
1600-
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
1601-
):
1633+
1634+
if is_append_action:
16021635
kwargs['action'] = 'append'
16031636
if _annotation_contains_types(field_info.annotation, (dict, Mapping), is_strip_annotated=True):
16041637
self._cli_dict_args[kwargs['dest']] = field_info.annotation
16051638

1606-
arg_names = self._get_arg_names(arg_prefix, subcommand_prefix, alias_prefixes, resolved_names)
1639+
arg_names = self._get_arg_names(arg_prefix, subcommand_prefix, alias_prefixes, alias_names)
16071640
if _CliPositionalArg in field_info.metadata:
1608-
kwargs['metavar'] = resolved_names[0].upper()
1641+
kwargs['metavar'] = preferred_alias.upper()
16091642
arg_names = [kwargs['dest']]
16101643
del kwargs['dest']
16111644
del kwargs['required']
16121645
arg_flag = ''
16131646

16141647
self._convert_bool_flag(kwargs, field_info, model_default)
16151648

1616-
if sub_models and kwargs.get('action') != 'append':
1649+
if is_parser_submodel:
16171650
self._add_parser_submodels(
16181651
parser,
16191652
sub_models,
@@ -1625,14 +1658,10 @@ def _add_parser_args(
16251658
kwargs,
16261659
field_name,
16271660
field_info,
1628-
resolved_names,
1661+
alias_names,
16291662
model_default=model_default,
16301663
)
16311664
elif not is_alias_path_only:
1632-
if arg_prefix and field_info.validation_alias is not None:
1633-
# Strip prefix if validation alias is set and value is not complex.
1634-
# Related https://github.com/pydantic/pydantic-settings/pull/25
1635-
kwargs['dest'] = kwargs['dest'][self.env_prefix_len :]
16361665
if group is not None:
16371666
if isinstance(group, dict):
16381667
group = self._add_argument_group(parser, **group)
@@ -1662,11 +1691,11 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode
16621691
)
16631692

16641693
def _get_arg_names(
1665-
self, arg_prefix: str, subcommand_prefix: str, alias_prefixes: list[str], resolved_names: tuple[str, ...]
1694+
self, arg_prefix: str, subcommand_prefix: str, alias_prefixes: list[str], alias_names: tuple[str, ...]
16661695
) -> list[str]:
16671696
arg_names: list[str] = []
16681697
for prefix in [arg_prefix] + alias_prefixes:
1669-
for name in resolved_names:
1698+
for name in alias_names:
16701699
arg_names.append(
16711700
f'{prefix}{name}'
16721701
if subcommand_prefix == self.env_prefix
@@ -1686,7 +1715,7 @@ def _add_parser_submodels(
16861715
kwargs: dict[str, Any],
16871716
field_name: str,
16881717
field_info: FieldInfo,
1689-
resolved_names: tuple[str, ...],
1718+
alias_names: tuple[str, ...],
16901719
model_default: Any,
16911720
) -> None:
16921721
model_group: Any = None
@@ -1711,6 +1740,7 @@ def _add_parser_submodels(
17111740
else:
17121741
model_group_kwargs['description'] = desc_header
17131742

1743+
preferred_alias = alias_names[0]
17141744
if not self.cli_avoid_json:
17151745
added_args.append(arg_names[0])
17161746
kwargs['help'] = f'set {arg_names[0]} from JSON string'
@@ -1721,10 +1751,10 @@ def _add_parser_submodels(
17211751
parser=parser,
17221752
model=model,
17231753
added_args=added_args,
1724-
arg_prefix=f'{arg_prefix}{resolved_names[0]}.',
1754+
arg_prefix=f'{arg_prefix}{preferred_alias}.',
17251755
subcommand_prefix=subcommand_prefix,
17261756
group=model_group if model_group else model_group_kwargs,
1727-
alias_prefixes=[f'{arg_prefix}{name}.' for name in resolved_names[1:]],
1757+
alias_prefixes=[f'{arg_prefix}{name}.' for name in alias_names[1:]],
17281758
model_default=model_default,
17291759
)
17301760

0 commit comments

Comments
 (0)