8
8
import typing
9
9
import warnings
10
10
from abc import ABC , abstractmethod
11
+
12
+ if sys .version_info >= (3 , 9 ):
13
+ from argparse import BooleanOptionalAction
11
14
from argparse import SUPPRESS , ArgumentParser , Namespace , RawDescriptionHelpFormatter , _SubParsersAction
12
15
from collections import deque
13
16
from dataclasses import is_dataclass
@@ -124,6 +127,14 @@ class _CliPositionalArg:
124
127
pass
125
128
126
129
130
+ class _CliImplicitFlag :
131
+ pass
132
+
133
+
134
+ class _CliExplicitFlag :
135
+ pass
136
+
137
+
127
138
class _CliInternalArgParser (ArgumentParser ):
128
139
def __init__ (self , cli_exit_on_error : bool = True , ** kwargs : Any ) -> None :
129
140
super ().__init__ (** kwargs )
@@ -138,6 +149,9 @@ def error(self, message: str) -> NoReturn:
138
149
T = TypeVar ('T' )
139
150
CliSubCommand = Annotated [Union [T , None ], _CliSubCommand ]
140
151
CliPositionalArg = Annotated [T , _CliPositionalArg ]
152
+ _CliBoolFlag = TypeVar ('_CliBoolFlag' , bound = bool )
153
+ CliImplicitFlag = Annotated [_CliBoolFlag , _CliImplicitFlag ]
154
+ CliExplicitFlag = Annotated [_CliBoolFlag , _CliExplicitFlag ]
141
155
142
156
143
157
class EnvNoneType (str ):
@@ -905,6 +919,8 @@ class CliSettingsSource(EnvSettingsSource, Generic[T]):
905
919
cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
906
920
Defaults to `True`.
907
921
cli_prefix: Prefix for command line arguments added under the root parser. Defaults to "".
922
+ cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags.
923
+ (e.g. --flag, --no-flag). Defaults to `False`.
908
924
case_sensitive: Whether CLI "--arg" names should be read with case-sensitivity. Defaults to `True`.
909
925
Note: Case-insensitive matching is only supported on the internal root parser and does not apply to CLI
910
926
subcommands.
@@ -932,6 +948,7 @@ def __init__(
932
948
cli_use_class_docs_for_groups : bool | None = None ,
933
949
cli_exit_on_error : bool | None = None ,
934
950
cli_prefix : str | None = None ,
951
+ cli_implicit_flags : bool | None = None ,
935
952
case_sensitive : bool | None = True ,
936
953
root_parser : Any = None ,
937
954
parse_args_method : Callable [..., Any ] | None = ArgumentParser .parse_args ,
@@ -975,6 +992,11 @@ def __init__(
975
992
if cli_prefix .startswith ('.' ) or cli_prefix .endswith ('.' ) or not cli_prefix .replace ('.' , '' ).isidentifier (): # type: ignore
976
993
raise SettingsError (f'CLI settings source prefix is invalid: { cli_prefix } ' )
977
994
self .cli_prefix += '.'
995
+ self .cli_implicit_flags = (
996
+ cli_implicit_flags
997
+ if cli_implicit_flags is not None
998
+ else settings_cls .model_config .get ('cli_implicit_flags' , False )
999
+ )
978
1000
979
1001
case_sensitive = case_sensitive if case_sensitive is not None else True
980
1002
if not case_sensitive and root_parser is not None :
@@ -1281,6 +1303,23 @@ def _get_resolved_names(
1281
1303
resolved_names = [resolved_name .lower () for resolved_name in resolved_names ]
1282
1304
return tuple (dict .fromkeys (resolved_names )), is_alias_path_only
1283
1305
1306
+ def _verify_cli_flag_annotations (self , model : type [BaseModel ], field_name : str , field_info : FieldInfo ) -> None :
1307
+ if _CliImplicitFlag in field_info .metadata :
1308
+ cli_flag_name = 'CliImplicitFlag'
1309
+ elif _CliExplicitFlag in field_info .metadata :
1310
+ cli_flag_name = 'CliExplicitFlag'
1311
+ else :
1312
+ return
1313
+
1314
+ if field_info .annotation is not bool :
1315
+ raise SettingsError (f'{ cli_flag_name } argument { model .__name__ } .{ field_name } is not of type bool' )
1316
+ elif sys .version_info < (3 , 9 ) and (
1317
+ field_info .default is PydanticUndefined and field_info .default_factory is None
1318
+ ):
1319
+ raise SettingsError (
1320
+ f'{ cli_flag_name } argument { model .__name__ } .{ field_name } must have default for python versions < 3.9'
1321
+ )
1322
+
1284
1323
def _sort_arg_fields (self , model : type [BaseModel ]) -> list [tuple [str , FieldInfo ]]:
1285
1324
positional_args , subcommand_args , optional_args = [], [], []
1286
1325
fields = (
@@ -1310,6 +1349,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
1310
1349
raise SettingsError (f'positional argument { model .__name__ } .{ field_name } has an alias' )
1311
1350
positional_args .append ((field_name , field_info ))
1312
1351
else :
1352
+ self ._verify_cli_flag_annotations (model , field_name , field_info )
1313
1353
optional_args .append ((field_name , field_info ))
1314
1354
return positional_args + subcommand_args + optional_args
1315
1355
@@ -1457,6 +1497,8 @@ def _add_parser_args(
1457
1497
del kwargs ['required' ]
1458
1498
arg_flag = ''
1459
1499
1500
+ self ._convert_bool_flag (kwargs , field_info , model_default )
1501
+
1460
1502
if sub_models and kwargs .get ('action' ) != 'append' :
1461
1503
self ._add_parser_submodels (
1462
1504
parser ,
@@ -1486,6 +1528,22 @@ def _add_parser_args(
1486
1528
self ._add_parser_alias_paths (parser , alias_path_args , added_args , arg_prefix , subcommand_prefix , group )
1487
1529
return parser
1488
1530
1531
+ def _convert_bool_flag (self , kwargs : dict [str , Any ], field_info : FieldInfo , model_default : Any ) -> None :
1532
+ if kwargs ['metavar' ] == 'bool' :
1533
+ default = None
1534
+ if field_info .default is not PydanticUndefined :
1535
+ default = field_info .default
1536
+ if model_default is not PydanticUndefined :
1537
+ default = model_default
1538
+ if sys .version_info >= (3 , 9 ) or isinstance (default , bool ):
1539
+ if (self .cli_implicit_flags or _CliImplicitFlag in field_info .metadata ) and (
1540
+ _CliExplicitFlag not in field_info .metadata
1541
+ ):
1542
+ del kwargs ['metavar' ]
1543
+ kwargs ['action' ] = (
1544
+ BooleanOptionalAction if sys .version_info >= (3 , 9 ) else f'store_{ str (not default ).lower ()} '
1545
+ )
1546
+
1489
1547
def _get_arg_names (
1490
1548
self , arg_prefix : str , subcommand_prefix : str , alias_prefixes : list [str ], resolved_names : tuple [str , ...]
1491
1549
) -> list [str ]:
0 commit comments