Skip to content

Commit 136c2c7

Browse files
authored
Add CLI bool flags (#365)
1 parent 8fb9abb commit 136c2c7

File tree

5 files changed

+215
-1
lines changed

5 files changed

+215
-1
lines changed

docs/index.md

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,75 @@ options:
868868
"""
869869
```
870870

871+
#### CLI Boolean Flags
872+
873+
Change whether boolean fields should be explicit or implicit by default using the `cli_implicit_flags` setting. By
874+
default, boolean fields are "explicit", meaning a boolean value must be explicitly provided on the CLI, e.g.
875+
`--flag=True`. Conversely, boolean fields that are "implicit" derive the value from the flag itself, e.g.
876+
`--flag,--no-flag`, which removes the need for an explicit value to be passed.
877+
878+
Additionally, the provided `CliImplicitFlag` and `CliExplicitFlag` annotations can be used for more granular control
879+
when necessary.
880+
881+
!!! note
882+
For `python < 3.9`:
883+
* The `--no-flag` option is not generated due to an underlying `argparse` limitation.
884+
* The `CliImplicitFlag` and `CliExplicitFlag` annotations can only be applied to optional bool fields.
885+
886+
```py
887+
from pydantic_settings import BaseSettings, CliExplicitFlag, CliImplicitFlag
888+
889+
890+
class ExplicitSettings(BaseSettings, cli_parse_args=True):
891+
"""Boolean fields are explicit by default."""
892+
893+
explicit_req: bool
894+
"""
895+
--explicit_req bool (required)
896+
"""
897+
898+
explicit_opt: bool = False
899+
"""
900+
--explicit_opt bool (default: False)
901+
"""
902+
903+
# Booleans are explicit by default, so must override implicit flags with annotation
904+
implicit_req: CliImplicitFlag[bool]
905+
"""
906+
--implicit_req, --no-implicit_req (required)
907+
"""
908+
909+
implicit_opt: CliImplicitFlag[bool] = False
910+
"""
911+
--implicit_opt, --no-implicit_opt (default: False)
912+
"""
913+
914+
915+
class ImplicitSettings(BaseSettings, cli_parse_args=True, cli_implicit_flags=True):
916+
"""With cli_implicit_flags=True, boolean fields are implicit by default."""
917+
918+
# Booleans are implicit by default, so must override explicit flags with annotation
919+
explicit_req: CliExplicitFlag[bool]
920+
"""
921+
--explicit_req bool (required)
922+
"""
923+
924+
explicit_opt: CliExplicitFlag[bool] = False
925+
"""
926+
--explicit_opt bool (default: False)
927+
"""
928+
929+
implicit_req: bool
930+
"""
931+
--implicit_req, --no-implicit_req (required)
932+
"""
933+
934+
implicit_opt: bool = False
935+
"""
936+
--implicit_opt, --no-implicit_opt (default: False)
937+
"""
938+
```
939+
871940
#### Change Whether CLI Should Exit on Error
872941

873942
Change whether the CLI internal parser will exit on error or raise a `SettingsError` exception by using

pydantic_settings/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from .main import BaseSettings, SettingsConfigDict
22
from .sources import (
33
AzureKeyVaultSettingsSource,
4+
CliExplicitFlag,
5+
CliImplicitFlag,
46
CliPositionalArg,
57
CliSettingsSource,
68
CliSubCommand,
@@ -24,6 +26,8 @@
2426
'CliSettingsSource',
2527
'CliSubCommand',
2628
'CliPositionalArg',
29+
'CliExplicitFlag',
30+
'CliImplicitFlag',
2731
'InitSettingsSource',
2832
'JsonConfigSettingsSource',
2933
'PyprojectTomlConfigSettingsSource',

pydantic_settings/main.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class SettingsConfigDict(ConfigDict, total=False):
4040
cli_use_class_docs_for_groups: bool
4141
cli_exit_on_error: bool
4242
cli_prefix: str
43+
cli_implicit_flags: bool | None
4344
secrets_dir: str | Path | None
4445
json_file: PathType | None
4546
json_file_encoding: str | None
@@ -114,6 +115,8 @@ class BaseSettings(BaseModel):
114115
_cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
115116
Defaults to `True`.
116117
_cli_prefix: The root parser command line arguments prefix. Defaults to "".
118+
_cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags.
119+
(e.g. --flag, --no-flag). Defaults to `False`.
117120
_secrets_dir: The secret files directory. Defaults to `None`.
118121
"""
119122

@@ -137,6 +140,7 @@ def __init__(
137140
_cli_use_class_docs_for_groups: bool | None = None,
138141
_cli_exit_on_error: bool | None = None,
139142
_cli_prefix: str | None = None,
143+
_cli_implicit_flags: bool | None = None,
140144
_secrets_dir: str | Path | None = None,
141145
**values: Any,
142146
) -> None:
@@ -162,6 +166,7 @@ def __init__(
162166
_cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups,
163167
_cli_exit_on_error=_cli_exit_on_error,
164168
_cli_prefix=_cli_prefix,
169+
_cli_implicit_flags=_cli_implicit_flags,
165170
_secrets_dir=_secrets_dir,
166171
)
167172
)
@@ -211,6 +216,7 @@ def _settings_build_values(
211216
_cli_use_class_docs_for_groups: bool | None = None,
212217
_cli_exit_on_error: bool | None = None,
213218
_cli_prefix: str | None = None,
219+
_cli_implicit_flags: bool | None = None,
214220
_secrets_dir: str | Path | None = None,
215221
) -> dict[str, Any]:
216222
# Determine settings config values
@@ -260,6 +266,9 @@ def _settings_build_values(
260266
_cli_exit_on_error if _cli_exit_on_error is not None else self.model_config.get('cli_exit_on_error')
261267
)
262268
cli_prefix = _cli_prefix if _cli_prefix is not None else self.model_config.get('cli_prefix')
269+
cli_implicit_flags = (
270+
_cli_implicit_flags if _cli_implicit_flags is not None else self.model_config.get('cli_implicit_flags')
271+
)
263272

264273
secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir')
265274

@@ -311,6 +320,7 @@ def _settings_build_values(
311320
cli_use_class_docs_for_groups=cli_use_class_docs_for_groups,
312321
cli_exit_on_error=cli_exit_on_error,
313322
cli_prefix=cli_prefix,
323+
cli_implicit_flags=cli_implicit_flags,
314324
case_sensitive=case_sensitive,
315325
)
316326
if cli_settings_source is None
@@ -358,6 +368,7 @@ def _settings_build_values(
358368
cli_use_class_docs_for_groups=False,
359369
cli_exit_on_error=True,
360370
cli_prefix='',
371+
cli_implicit_flags=False,
361372
json_file=None,
362373
json_file_encoding=None,
363374
yaml_file=None,

pydantic_settings/sources.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import typing
99
import warnings
1010
from abc import ABC, abstractmethod
11+
12+
if sys.version_info >= (3, 9):
13+
from argparse import BooleanOptionalAction
1114
from argparse import SUPPRESS, ArgumentParser, Namespace, RawDescriptionHelpFormatter, _SubParsersAction
1215
from collections import deque
1316
from dataclasses import is_dataclass
@@ -124,6 +127,14 @@ class _CliPositionalArg:
124127
pass
125128

126129

130+
class _CliImplicitFlag:
131+
pass
132+
133+
134+
class _CliExplicitFlag:
135+
pass
136+
137+
127138
class _CliInternalArgParser(ArgumentParser):
128139
def __init__(self, cli_exit_on_error: bool = True, **kwargs: Any) -> None:
129140
super().__init__(**kwargs)
@@ -138,6 +149,9 @@ def error(self, message: str) -> NoReturn:
138149
T = TypeVar('T')
139150
CliSubCommand = Annotated[Union[T, None], _CliSubCommand]
140151
CliPositionalArg = Annotated[T, _CliPositionalArg]
152+
_CliBoolFlag = TypeVar('_CliBoolFlag', bound=bool)
153+
CliImplicitFlag = Annotated[_CliBoolFlag, _CliImplicitFlag]
154+
CliExplicitFlag = Annotated[_CliBoolFlag, _CliExplicitFlag]
141155

142156

143157
class EnvNoneType(str):
@@ -905,6 +919,8 @@ class CliSettingsSource(EnvSettingsSource, Generic[T]):
905919
cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
906920
Defaults to `True`.
907921
cli_prefix: Prefix for command line arguments added under the root parser. Defaults to "".
922+
cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags.
923+
(e.g. --flag, --no-flag). Defaults to `False`.
908924
case_sensitive: Whether CLI "--arg" names should be read with case-sensitivity. Defaults to `True`.
909925
Note: Case-insensitive matching is only supported on the internal root parser and does not apply to CLI
910926
subcommands.
@@ -932,6 +948,7 @@ def __init__(
932948
cli_use_class_docs_for_groups: bool | None = None,
933949
cli_exit_on_error: bool | None = None,
934950
cli_prefix: str | None = None,
951+
cli_implicit_flags: bool | None = None,
935952
case_sensitive: bool | None = True,
936953
root_parser: Any = None,
937954
parse_args_method: Callable[..., Any] | None = ArgumentParser.parse_args,
@@ -975,6 +992,11 @@ def __init__(
975992
if cli_prefix.startswith('.') or cli_prefix.endswith('.') or not cli_prefix.replace('.', '').isidentifier(): # type: ignore
976993
raise SettingsError(f'CLI settings source prefix is invalid: {cli_prefix}')
977994
self.cli_prefix += '.'
995+
self.cli_implicit_flags = (
996+
cli_implicit_flags
997+
if cli_implicit_flags is not None
998+
else settings_cls.model_config.get('cli_implicit_flags', False)
999+
)
9781000

9791001
case_sensitive = case_sensitive if case_sensitive is not None else True
9801002
if not case_sensitive and root_parser is not None:
@@ -1281,6 +1303,23 @@ def _get_resolved_names(
12811303
resolved_names = [resolved_name.lower() for resolved_name in resolved_names]
12821304
return tuple(dict.fromkeys(resolved_names)), is_alias_path_only
12831305

1306+
def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:
1307+
if _CliImplicitFlag in field_info.metadata:
1308+
cli_flag_name = 'CliImplicitFlag'
1309+
elif _CliExplicitFlag in field_info.metadata:
1310+
cli_flag_name = 'CliExplicitFlag'
1311+
else:
1312+
return
1313+
1314+
if field_info.annotation is not bool:
1315+
raise SettingsError(f'{cli_flag_name} argument {model.__name__}.{field_name} is not of type bool')
1316+
elif sys.version_info < (3, 9) and (
1317+
field_info.default is PydanticUndefined and field_info.default_factory is None
1318+
):
1319+
raise SettingsError(
1320+
f'{cli_flag_name} argument {model.__name__}.{field_name} must have default for python versions < 3.9'
1321+
)
1322+
12841323
def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]:
12851324
positional_args, subcommand_args, optional_args = [], [], []
12861325
fields = (
@@ -1310,6 +1349,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
13101349
raise SettingsError(f'positional argument {model.__name__}.{field_name} has an alias')
13111350
positional_args.append((field_name, field_info))
13121351
else:
1352+
self._verify_cli_flag_annotations(model, field_name, field_info)
13131353
optional_args.append((field_name, field_info))
13141354
return positional_args + subcommand_args + optional_args
13151355

@@ -1457,6 +1497,8 @@ def _add_parser_args(
14571497
del kwargs['required']
14581498
arg_flag = ''
14591499

1500+
self._convert_bool_flag(kwargs, field_info, model_default)
1501+
14601502
if sub_models and kwargs.get('action') != 'append':
14611503
self._add_parser_submodels(
14621504
parser,
@@ -1486,6 +1528,22 @@ def _add_parser_args(
14861528
self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
14871529
return parser
14881530

1531+
def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, model_default: Any) -> None:
1532+
if kwargs['metavar'] == 'bool':
1533+
default = None
1534+
if field_info.default is not PydanticUndefined:
1535+
default = field_info.default
1536+
if model_default is not PydanticUndefined:
1537+
default = model_default
1538+
if sys.version_info >= (3, 9) or isinstance(default, bool):
1539+
if (self.cli_implicit_flags or _CliImplicitFlag in field_info.metadata) and (
1540+
_CliExplicitFlag not in field_info.metadata
1541+
):
1542+
del kwargs['metavar']
1543+
kwargs['action'] = (
1544+
BooleanOptionalAction if sys.version_info >= (3, 9) else f'store_{str(not default).lower()}'
1545+
)
1546+
14891547
def _get_arg_names(
14901548
self, arg_prefix: str, subcommand_prefix: str, alias_prefixes: list[str], resolved_names: tuple[str, ...]
14911549
) -> list[str]:

tests/test_settings.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,14 @@
5050
TomlConfigSettingsSource,
5151
YamlConfigSettingsSource,
5252
)
53-
from pydantic_settings.sources import CliPositionalArg, CliSettingsSource, CliSubCommand, SettingsError
53+
from pydantic_settings.sources import (
54+
CliExplicitFlag,
55+
CliImplicitFlag,
56+
CliPositionalArg,
57+
CliSettingsSource,
58+
CliSubCommand,
59+
SettingsError,
60+
)
5461

5562
try:
5663
import dotenv
@@ -3119,6 +3126,71 @@ class InvalidCliParseArgsType(BaseSettings, cli_parse_args='invalid type'):
31193126

31203127
InvalidCliParseArgsType()
31213128

3129+
with pytest.raises(SettingsError, match='CliExplicitFlag argument CliFlagNotBool.flag is not of type bool'):
3130+
3131+
class CliFlagNotBool(BaseSettings, cli_parse_args=True):
3132+
flag: CliExplicitFlag[int] = False
3133+
3134+
CliFlagNotBool()
3135+
3136+
if sys.version_info < (3, 9):
3137+
with pytest.raises(
3138+
SettingsError,
3139+
match='CliImplicitFlag argument CliFlag38NotOpt.flag must have default for python versions < 3.9',
3140+
):
3141+
3142+
class CliFlag38NotOpt(BaseSettings, cli_parse_args=True):
3143+
flag: CliImplicitFlag[bool]
3144+
3145+
CliFlag38NotOpt()
3146+
3147+
3148+
@pytest.mark.parametrize('enforce_required', [True, False])
3149+
def test_cli_bool_flags(monkeypatch, enforce_required):
3150+
if sys.version_info < (3, 9):
3151+
3152+
class ExplicitSettings(BaseSettings, cli_enforce_required=enforce_required):
3153+
explicit_req: bool
3154+
explicit_opt: bool = False
3155+
implicit_opt: CliImplicitFlag[bool] = False
3156+
3157+
class ImplicitSettings(BaseSettings, cli_implicit_flags=True, cli_enforce_required=enforce_required):
3158+
explicit_req: bool
3159+
explicit_opt: CliExplicitFlag[bool] = False
3160+
implicit_opt: bool = False
3161+
3162+
expected = {
3163+
'explicit_req': True,
3164+
'explicit_opt': False,
3165+
'implicit_opt': False,
3166+
}
3167+
3168+
assert ExplicitSettings(_cli_parse_args=['--explicit_req=True']).model_dump() == expected
3169+
assert ImplicitSettings(_cli_parse_args=['--explicit_req=True']).model_dump() == expected
3170+
else:
3171+
3172+
class ExplicitSettings(BaseSettings, cli_enforce_required=enforce_required):
3173+
explicit_req: bool
3174+
explicit_opt: bool = False
3175+
implicit_req: CliImplicitFlag[bool]
3176+
implicit_opt: CliImplicitFlag[bool] = False
3177+
3178+
class ImplicitSettings(BaseSettings, cli_implicit_flags=True, cli_enforce_required=enforce_required):
3179+
explicit_req: CliExplicitFlag[bool]
3180+
explicit_opt: CliExplicitFlag[bool] = False
3181+
implicit_req: bool
3182+
implicit_opt: bool = False
3183+
3184+
expected = {
3185+
'explicit_req': True,
3186+
'explicit_opt': False,
3187+
'implicit_req': True,
3188+
'implicit_opt': False,
3189+
}
3190+
3191+
assert ExplicitSettings(_cli_parse_args=['--explicit_req=True', '--implicit_req']).model_dump() == expected
3192+
assert ImplicitSettings(_cli_parse_args=['--explicit_req=True', '--implicit_req']).model_dump() == expected
3193+
31223194

31233195
def test_cli_avoid_json(capsys, monkeypatch):
31243196
class SubModel(BaseModel):

0 commit comments

Comments
 (0)