Skip to content

Commit 47924f5

Browse files
authored
Add get_subcommand function. (#341)
1 parent cabcdee commit 47924f5

File tree

4 files changed

+139
-101
lines changed

4 files changed

+139
-101
lines changed

docs/index.md

Lines changed: 35 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,9 @@ Subcommands and positional arguments are expressed using the `CliSubCommand` and
747747
annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore,
748748
subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`.
749749

750+
Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is
751+
not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found.
752+
750753
!!! note
751754
CLI settings subcommands are limited to a single subparser per model. In other words, all subcommands for a model
752755
are grouped under a single subparser; it does not allow for multiple subparsers with each subparser having its own
@@ -759,114 +762,59 @@ subcommands must be a valid type derived from either a pydantic `BaseModel` or p
759762
```py
760763
import sys
761764

762-
from pydantic import BaseModel, Field
763-
from pydantic.dataclasses import dataclass
765+
from pydantic import BaseModel
764766

765767
from pydantic_settings import (
766768
BaseSettings,
767769
CliPositionalArg,
768770
CliSubCommand,
771+
SettingsError,
772+
get_subcommand,
769773
)
770774

771775

772-
@dataclass
773-
class FooPlugin:
774-
"""git-plugins-foo - Extra deep foo plugin command"""
775-
776-
x_feature: bool = Field(default=False, description='Enable "X" feature')
777-
778-
779-
@dataclass
780-
class BarPlugin:
781-
"""git-plugins-bar - Extra deep bar plugin command"""
782-
783-
y_feature: bool = Field(default=False, description='Enable "Y" feature')
784-
785-
786-
@dataclass
787-
class Plugins:
788-
"""git-plugins - Fake plugins for GIT"""
789-
790-
foo: CliSubCommand[FooPlugin] = Field(description='Foo is fake plugin')
791-
792-
bar: CliSubCommand[BarPlugin] = Field(description='Bar is fake plugin')
776+
class Init(BaseModel):
777+
directory: CliPositionalArg[str]
793778

794779

795780
class Clone(BaseModel):
796-
"""git-clone - Clone a repository into a new directory"""
797-
798-
repository: CliPositionalArg[str] = Field(description='The repo ...')
799-
800-
directory: CliPositionalArg[str] = Field(description='The dir ...')
801-
802-
local: bool = Field(default=False, description='When the repo ...')
803-
804-
805-
class Git(BaseSettings, cli_parse_args=True, cli_prog_name='git'):
806-
"""git - The stupid content tracker"""
807-
808-
clone: CliSubCommand[Clone] = Field(description='Clone a repo ...')
809-
810-
plugins: CliSubCommand[Plugins] = Field(description='Fake GIT plugins')
811-
812-
813-
try:
814-
sys.argv = ['example.py', '--help']
815-
Git()
816-
except SystemExit as e:
817-
print(e)
818-
#> 0
819-
"""
820-
usage: git [-h] {clone,plugins} ...
781+
repository: CliPositionalArg[str]
782+
directory: CliPositionalArg[str]
821783

822-
git - The stupid content tracker
823784

824-
options:
825-
-h, --help show this help message and exit
785+
class Git(BaseSettings, cli_parse_args=True, cli_exit_on_error=False):
786+
clone: CliSubCommand[Clone]
787+
init: CliSubCommand[Init]
826788

827-
subcommands:
828-
{clone,plugins}
829-
clone Clone a repo ...
830-
plugins Fake GIT plugins
831-
"""
832789

790+
# Run without subcommands
791+
sys.argv = ['example.py']
792+
cmd = Git()
793+
assert cmd.model_dump() == {'clone': None, 'init': None}
833794

834795
try:
835-
sys.argv = ['example.py', 'clone', '--help']
836-
Git()
837-
except SystemExit as e:
838-
print(e)
839-
#> 0
840-
"""
841-
usage: git clone [-h] [--local bool] [--shared bool] REPOSITORY DIRECTORY
796+
# Will raise an error since no subcommand was provided
797+
get_subcommand(cmd).model_dump()
798+
except SettingsError as err:
799+
assert str(err) == 'Error: CLI subcommand is required {clone, init}'
842800

843-
git-clone - Clone a repository into a new directory
801+
# Will not raise an error since subcommand is not required
802+
assert get_subcommand(cmd, is_required=False) is None
844803

845-
positional arguments:
846-
REPOSITORY The repo ...
847-
DIRECTORY The dir ...
848-
849-
options:
850-
-h, --help show this help message and exit
851-
--local bool When the repo ... (default: False)
852-
"""
853804

805+
# Run the clone subcommand
806+
sys.argv = ['example.py', 'clone', 'repo', 'dest']
807+
cmd = Git()
808+
assert cmd.model_dump() == {
809+
'clone': {'repository': 'repo', 'directory': 'dest'},
810+
'init': None,
811+
}
854812

855-
try:
856-
sys.argv = ['example.py', 'plugins', 'bar', '--help']
857-
Git()
858-
except SystemExit as e:
859-
print(e)
860-
#> 0
861-
"""
862-
usage: git plugins bar [-h] [--my_feature bool]
863-
864-
git-plugins-bar - Extra deep bar plugin command
865-
866-
options:
867-
-h, --help show this help message and exit
868-
--y_feature bool Enable "Y" feature (default: False)
869-
"""
813+
# Returns the subcommand model instance (in this case, 'clone')
814+
assert get_subcommand(cmd).model_dump() == {
815+
'directory': 'dest',
816+
'repository': 'repo',
817+
}
870818
```
871819

872820
### Customizing the CLI Experience

pydantic_settings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
SettingsError,
1717
TomlConfigSettingsSource,
1818
YamlConfigSettingsSource,
19+
get_subcommand,
1920
)
2021
from .version import VERSION
2122

@@ -38,6 +39,7 @@
3839
'TomlConfigSettingsSource',
3940
'YamlConfigSettingsSource',
4041
'AzureKeyVaultSettingsSource',
42+
'get_subcommand',
4143
'__version__',
4244
)
4345

pydantic_settings/sources.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,53 @@ def error(self, message: str) -> NoReturn:
156156
CliExplicitFlag = Annotated[_CliBoolFlag, _CliExplicitFlag]
157157

158158

159+
def get_subcommand(model: BaseModel, is_required: bool = True, cli_exit_on_error: bool | None = None) -> Any:
160+
"""
161+
Get the subcommand from a model.
162+
163+
Args:
164+
model: The model to get the subcommand from.
165+
is_required: Determines whether a model must have subcommand set and raises error if not
166+
found. Defaults to `True`.
167+
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
168+
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
169+
170+
Returns:
171+
The subcommand model if found, otherwise `None`.
172+
173+
Raises:
174+
SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True`
175+
(the default).
176+
SettingsError: When no subcommand is found and is_required=`True` and
177+
cli_exit_on_error=`False`.
178+
"""
179+
180+
model_cls = type(model)
181+
if cli_exit_on_error is None and is_model_class(model_cls):
182+
model_default = model.model_config.get('cli_exit_on_error')
183+
if isinstance(model_default, bool):
184+
cli_exit_on_error = model_default
185+
if cli_exit_on_error is None:
186+
cli_exit_on_error = True
187+
188+
subcommands: list[str] = []
189+
for field_name, field_info in _get_model_fields(model_cls).items():
190+
if _CliSubCommand in field_info.metadata:
191+
if getattr(model, field_name) is not None:
192+
return getattr(model, field_name)
193+
subcommands.append(field_name)
194+
195+
if is_required:
196+
error_message = (
197+
f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}'
198+
if subcommands
199+
else 'Error: CLI subcommand is required but no subcommands were found.'
200+
)
201+
raise SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message)
202+
203+
return None
204+
205+
159206
class EnvNoneType(str):
160207
pass
161208

@@ -763,11 +810,7 @@ class Cfg(BaseSettings):
763810
if type_has_key:
764811
return type_has_key
765812
elif is_model_class(annotation) or is_pydantic_dataclass(annotation):
766-
fields = (
767-
annotation.__pydantic_fields__
768-
if is_pydantic_dataclass(annotation) and hasattr(annotation, '__pydantic_fields__')
769-
else cast(BaseModel, annotation).model_fields
770-
)
813+
fields = _get_model_fields(annotation)
771814
# `case_sensitive is None` is here to be compatible with the old behavior.
772815
# Has to be removed in V3.
773816
if (case_sensitive is None or case_sensitive) and fields.get(key):
@@ -1376,12 +1419,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str,
13761419

13771420
def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]:
13781421
positional_args, subcommand_args, optional_args = [], [], []
1379-
fields = (
1380-
model.__pydantic_fields__
1381-
if hasattr(model, '__pydantic_fields__') and is_pydantic_dataclass(model)
1382-
else model.model_fields
1383-
)
1384-
for field_name, field_info in fields.items():
1422+
for field_name, field_info in _get_model_fields(model).items():
13851423
if _CliSubCommand in field_info.metadata:
13861424
if not field_info.is_required():
13871425
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value')
@@ -1496,9 +1534,7 @@ def _add_parser_args(
14961534
sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info)
14971535
if _CliSubCommand in field_info.metadata:
14981536
if subparsers is None:
1499-
subparsers = self._add_subparsers(
1500-
parser, title='subcommands', dest=f'{arg_prefix}:subcommand', required=self.cli_enforce_required
1501-
)
1537+
subparsers = self._add_subparsers(parser, title='subcommands', dest=f'{arg_prefix}:subcommand')
15021538
self._cli_subcommands[f'{arg_prefix}:subcommand'] = [f'{arg_prefix}{field_name}']
15031539
else:
15041540
self._cli_subcommands[f'{arg_prefix}:subcommand'].append(f'{arg_prefix}{field_name}')
@@ -2095,5 +2131,13 @@ def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any
20952131
return None
20962132

20972133

2134+
def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]:
2135+
if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'):
2136+
return model_cls.__pydantic_fields__
2137+
if is_model_class(model_cls):
2138+
return model_cls.model_fields
2139+
raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')
2140+
2141+
20982142
def _is_function(obj: Any) -> bool:
20992143
return inspect.isfunction(obj) or inspect.isbuiltin(obj) or inspect.isroutine(obj) or inspect.ismethod(obj)

tests/test_settings.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
AliasChoices,
2121
AliasPath,
2222
BaseModel,
23+
ConfigDict,
2324
DirectoryPath,
2425
Discriminator,
2526
Field,
@@ -59,6 +60,7 @@
5960
CliSubCommand,
6061
DefaultSettingsSource,
6162
SettingsError,
63+
get_subcommand,
6264
)
6365

6466
try:
@@ -3095,6 +3097,12 @@ class FooPlugin:
30953097
class BarPlugin:
30963098
my_feature: bool = False
30973099

3100+
bar = BarPlugin()
3101+
with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'):
3102+
get_subcommand(bar)
3103+
with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'):
3104+
get_subcommand(bar, cli_exit_on_error=False)
3105+
30983106
@pydantic_dataclasses.dataclass
30993107
class Plugins:
31003108
foo: CliSubCommand[FooPlugin]
@@ -3116,26 +3124,62 @@ class Git(BaseSettings):
31163124
init: CliSubCommand[Init]
31173125
plugins: CliSubCommand[Plugins]
31183126

3127+
git = Git(_cli_parse_args=[])
3128+
assert git.model_dump() == {
3129+
'clone': None,
3130+
'init': None,
3131+
'plugins': None,
3132+
}
3133+
assert get_subcommand(git, is_required=False) is None
3134+
with pytest.raises(SystemExit, match='Error: CLI subcommand is required {clone, init, plugins}'):
3135+
get_subcommand(git)
3136+
with pytest.raises(SettingsError, match='Error: CLI subcommand is required {clone, init, plugins}'):
3137+
get_subcommand(git, cli_exit_on_error=False)
3138+
31193139
git = Git(_cli_parse_args=['init', '--quiet', 'true', 'dir/path'])
31203140
assert git.model_dump() == {
31213141
'clone': None,
31223142
'init': {'directory': 'dir/path', 'quiet': True, 'bare': False},
31233143
'plugins': None,
31243144
}
3145+
assert get_subcommand(git) == git.init
3146+
assert get_subcommand(git, is_required=False) == git.init
31253147

31263148
git = Git(_cli_parse_args=['clone', 'repo', '.', '--shared', 'true'])
31273149
assert git.model_dump() == {
31283150
'clone': {'repository': 'repo', 'directory': '.', 'local': False, 'shared': True},
31293151
'init': None,
31303152
'plugins': None,
31313153
}
3154+
assert get_subcommand(git) == git.clone
3155+
assert get_subcommand(git, is_required=False) == git.clone
31323156

31333157
git = Git(_cli_parse_args=['plugins', 'bar'])
31343158
assert git.model_dump() == {
31353159
'clone': None,
31363160
'init': None,
31373161
'plugins': {'foo': None, 'bar': {'my_feature': False}},
31383162
}
3163+
assert get_subcommand(git) == git.plugins
3164+
assert get_subcommand(git, is_required=False) == git.plugins
3165+
assert get_subcommand(get_subcommand(git)) == git.plugins.bar
3166+
assert get_subcommand(get_subcommand(git), is_required=False) == git.plugins.bar
3167+
3168+
class NotModel: ...
3169+
3170+
with pytest.raises(
3171+
SettingsError, match='Error: NotModel is not subclass of BaseModel or pydantic.dataclasses.dataclass'
3172+
):
3173+
get_subcommand(NotModel())
3174+
3175+
class NotSettingsConfigDict(BaseModel):
3176+
model_config = ConfigDict(cli_exit_on_error='not a bool')
3177+
3178+
with pytest.raises(SystemExit, match='Error: CLI subcommand is required but no subcommands were found.'):
3179+
get_subcommand(NotSettingsConfigDict())
3180+
3181+
with pytest.raises(SettingsError, match='Error: CLI subcommand is required but no subcommands were found.'):
3182+
get_subcommand(NotSettingsConfigDict(), cli_exit_on_error=False)
31393183

31403184

31413185
def test_cli_union_similar_sub_models():

0 commit comments

Comments
 (0)