35
35
)
36
36
37
37
import typing_extensions
38
- from pydantic import AliasChoices , AliasPath , BaseModel , Field
38
+ from pydantic import AliasChoices , AliasPath , BaseModel , Field , create_model
39
39
from pydantic ._internal ._repr import Representation
40
40
from pydantic ._internal ._utils import is_model_class
41
41
from pydantic .dataclasses import is_pydantic_dataclass
47
47
48
48
from ...exceptions import SettingsError
49
49
from ...utils import _lenient_issubclass , _WithArgsTypes
50
- from ..types import NoDecode , _CliExplicitFlag , _CliImplicitFlag , _CliPositionalArg , _CliSubCommand , _CliUnknownArgs
50
+ from ..types import (
51
+ NoDecode ,
52
+ PydanticModel ,
53
+ _CliExplicitFlag ,
54
+ _CliImplicitFlag ,
55
+ _CliPositionalArg ,
56
+ _CliSubCommand ,
57
+ _CliUnknownArgs ,
58
+ )
51
59
from ..utils import (
52
60
_annotation_contains_types ,
53
61
_annotation_enum_val_to_name ,
@@ -74,10 +82,6 @@ def error(self, message: str) -> NoReturn:
74
82
super ().error (message )
75
83
76
84
77
- class _CliInternalArgSerializer (_CliInternalArgParser ):
78
- pass
79
-
80
-
81
85
class CliMutuallyExclusiveGroup (BaseModel ):
82
86
pass
83
87
@@ -666,8 +670,6 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
666
670
self ._formatter_class = formatter_class
667
671
self ._cli_dict_args : dict [str , type [Any ] | None ] = {}
668
672
self ._cli_subcommands : defaultdict [str , dict [str , str ]] = defaultdict (dict )
669
- self ._is_serialize_args = isinstance (root_parser , _CliInternalArgSerializer )
670
- self ._serialize_positional_args : dict [str , Any ] = {}
671
673
self ._add_parser_args (
672
674
parser = self .root_parser ,
673
675
model = self .settings_cls ,
@@ -693,7 +695,6 @@ def _add_parser_args(
693
695
) -> ArgumentParser :
694
696
subparsers : Any = None
695
697
alias_path_args : dict [str , str ] = {}
696
- alias_path_only_defaults : dict [str , Any ] = {}
697
698
# Ignore model default if the default is a model and not a subclass of the current model.
698
699
model_default = (
699
700
None
@@ -762,11 +763,9 @@ def _add_parser_args(
762
763
is_append_action = _annotation_contains_types (
763
764
field_info .annotation , (list , set , dict , Sequence , Mapping ), is_strip_annotated = True
764
765
)
765
- is_parser_submodel = bool ( sub_models ) and not is_append_action
766
+ is_parser_submodel = sub_models and not is_append_action
766
767
kwargs : dict [str , Any ] = {}
767
- kwargs ['default' ] = self ._get_cli_default_value (
768
- field_name , field_info , model_default , is_parser_submodel
769
- )
768
+ kwargs ['default' ] = CLI_SUPPRESS
770
769
kwargs ['help' ] = self ._help_format (field_name , field_info , model_default , is_model_suppressed )
771
770
kwargs ['metavar' ] = self ._metavar_format (field_info .annotation )
772
771
kwargs ['required' ] = (
@@ -825,14 +824,8 @@ def _add_parser_args(
825
824
self ._add_argument (
826
825
parser , * (f'{ flag_prefix [: len (name )]} { name } ' for name in arg_names ), ** kwargs
827
826
)
828
- elif kwargs ['default' ] != CLI_SUPPRESS :
829
- self ._update_alias_path_only_defaults (
830
- kwargs ['dest' ], kwargs ['default' ], field_info , alias_path_only_defaults
831
- )
832
827
833
- self ._add_parser_alias_paths (
834
- parser , alias_path_args , added_args , arg_prefix , subcommand_prefix , group , alias_path_only_defaults
835
- )
828
+ self ._add_parser_alias_paths (parser , alias_path_args , added_args , arg_prefix , subcommand_prefix , group )
836
829
return parser
837
830
838
831
def _check_kebab_name (self , name : str ) -> str :
@@ -859,6 +852,8 @@ def _convert_positional_arg(
859
852
) -> tuple [list [str ], str ]:
860
853
flag_prefix = ''
861
854
arg_names = [kwargs ['dest' ]]
855
+ kwargs ['default' ] = PydanticUndefined
856
+ kwargs ['metavar' ] = self ._check_kebab_name (preferred_alias .upper ())
862
857
863
858
# Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in
864
859
# conjunction with model_default instead of the derived kwargs['required'].
@@ -869,13 +864,6 @@ def _convert_positional_arg(
869
864
elif not is_required :
870
865
kwargs ['nargs' ] = '?'
871
866
872
- if self ._is_serialize_args :
873
- self ._serialize_positional_args [kwargs ['dest' ]] = kwargs ['default' ]
874
- kwargs ['nargs' ] = '*'
875
-
876
- kwargs ['default' ] = PydanticUndefined
877
- kwargs ['metavar' ] = self ._check_kebab_name (preferred_alias .upper ())
878
-
879
867
del kwargs ['dest' ]
880
868
del kwargs ['required' ]
881
869
return arg_names , flag_prefix
@@ -963,7 +951,7 @@ def _add_parser_submodels(
963
951
is_model_suppressed = self ._is_field_suppressed (field_info ) or is_model_suppressed
964
952
if is_model_suppressed :
965
953
model_group_kwargs ['description' ] = CLI_SUPPRESS
966
- if not self .cli_avoid_json and not self . _is_serialize_args :
954
+ if not self .cli_avoid_json :
967
955
added_args .append (arg_names [0 ])
968
956
kwargs ['nargs' ] = '?'
969
957
kwargs ['const' ] = '{}'
@@ -993,7 +981,6 @@ def _add_parser_alias_paths(
993
981
arg_prefix : str ,
994
982
subcommand_prefix : str ,
995
983
group : Any ,
996
- alias_path_only_defaults : dict [str , Any ],
997
984
) -> None :
998
985
if alias_path_args :
999
986
context = parser
@@ -1009,9 +996,9 @@ def _add_parser_alias_paths(
1009
996
else f'{ arg_prefix .replace (subcommand_prefix , "" , 1 )} { name } '
1010
997
)
1011
998
kwargs : dict [str , Any ] = {}
999
+ kwargs ['default' ] = CLI_SUPPRESS
1012
1000
kwargs ['help' ] = 'pydantic alias path'
1013
1001
kwargs ['dest' ] = f'{ arg_prefix } { name } '
1014
- kwargs ['default' ] = alias_path_only_defaults .get (kwargs ['dest' ], CLI_SUPPRESS )
1015
1002
if metavar == 'dict' or is_nested_alias_path :
1016
1003
kwargs ['metavar' ] = 'dict'
1017
1004
else :
@@ -1105,34 +1092,27 @@ def _is_field_suppressed(self, field_info: FieldInfo) -> bool:
1105
1092
_help = field_info .description if field_info .description else ''
1106
1093
return _help == CLI_SUPPRESS or CLI_SUPPRESS in field_info .metadata
1107
1094
1108
- def _get_cli_default_value (
1109
- self , field_name : str , field_info : FieldInfo , model_default : Any , is_parser_submodel : bool
1110
- ) -> Any :
1111
- if is_parser_submodel or not isinstance (self .root_parser , _CliInternalArgSerializer ):
1112
- return CLI_SUPPRESS
1113
-
1114
- return getattr (model_default , field_name , field_info .default )
1115
-
1116
- def _update_alias_path_only_defaults (
1117
- self , dest : str , default : Any , field_info : FieldInfo , alias_path_only_defaults : dict [str , Any ]
1118
- ) -> None :
1095
+ @classmethod
1096
+ def _update_alias_path_only_default (
1097
+ cls , arg_name : str , value : Any , field_info : FieldInfo , alias_path_only_defaults : dict [str , Any ]
1098
+ ) -> tuple [str , list [Any ] | dict [str , Any ]]:
1119
1099
alias_path : AliasPath = [
1120
1100
alias if isinstance (alias , AliasPath ) else cast (AliasPath , alias .choices [0 ])
1121
1101
for alias in (field_info .alias , field_info .validation_alias )
1122
1102
if isinstance (alias , (AliasPath , AliasChoices ))
1123
1103
][0 ]
1124
1104
1125
1105
alias_nested_paths : list [str ] = alias_path .path [1 :- 1 ] # type: ignore
1126
- if '.' in dest :
1127
- alias_nested_paths = dest .split ('.' ) + alias_nested_paths
1128
- dest = alias_nested_paths .pop (0 )
1106
+ if '.' in arg_name :
1107
+ alias_nested_paths = arg_name .split ('.' ) + alias_nested_paths
1108
+ arg_name = alias_nested_paths .pop (0 )
1129
1109
1130
1110
if not alias_nested_paths :
1131
- alias_path_only_defaults .setdefault (dest , [])
1132
- alias_default = alias_path_only_defaults [dest ]
1111
+ alias_path_only_defaults .setdefault (arg_name , [])
1112
+ alias_default = alias_path_only_defaults [arg_name ]
1133
1113
else :
1134
- alias_path_only_defaults .setdefault (dest , {})
1135
- current_path = alias_path_only_defaults [dest ]
1114
+ alias_path_only_defaults .setdefault (arg_name , {})
1115
+ current_path = alias_path_only_defaults [arg_name ]
1136
1116
1137
1117
for nested_path in alias_nested_paths [:- 1 ]:
1138
1118
current_path .setdefault (nested_path , {})
@@ -1142,22 +1122,84 @@ def _update_alias_path_only_defaults(
1142
1122
1143
1123
alias_path_index = cast (int , alias_path .path [- 1 ])
1144
1124
alias_default .extend (['' ] * max (alias_path_index + 1 - len (alias_default ), 0 ))
1145
- alias_default [alias_path_index ] = default
1125
+ alias_default [alias_path_index ] = value
1126
+ return arg_name , alias_path_only_defaults [arg_name ]
1127
+
1128
+ @classmethod
1129
+ def _serialized_args (cls , model : PydanticModel , model_config : Any , prefix : str = '' ) -> list [str ]:
1130
+ model_field_definitions : dict [str , Any ] = {}
1131
+ for field_name , field_info in _get_model_fields (type (model )).items ():
1132
+ model_default = getattr (model , field_name )
1133
+ if field_info .default == model_default :
1134
+ continue
1135
+ if _CliSubCommand in field_info .metadata and model_default is None :
1136
+ continue
1137
+ model_field_definitions [field_name ] = (field_info .annotation , field_info )
1138
+ cli_serialize_cls = create_model ('CliSerialize' , __config__ = model_config , ** model_field_definitions )
1139
+
1140
+ added_args : set [str ] = set ()
1141
+ alias_path_args : dict [str , str ] = {}
1142
+ alias_path_only_defaults : dict [str , Any ] = {}
1143
+ optional_args : list [str | list [Any ] | dict [str , Any ]] = []
1144
+ positional_args : list [str | list [Any ] | dict [str , Any ]] = []
1145
+ subcommand_args : list [str ] = []
1146
+ cli_settings = CliSettingsSource [Any ](cli_serialize_cls )
1147
+ for field_name , field_info in _get_model_fields (cli_serialize_cls ).items ():
1148
+ model_default = getattr (model , field_name )
1149
+ alias_names , is_alias_path_only = _get_alias_names (
1150
+ field_name , field_info , alias_path_args = alias_path_args , case_sensitive = cli_settings .case_sensitive
1151
+ )
1152
+ preferred_alias = alias_names [0 ]
1153
+ if _CliSubCommand in field_info .metadata :
1154
+ subcommand_args .append (cls ._check_kebab_name (cli_settings , preferred_alias ))
1155
+ subcommand_args += cls ._serialized_args (model_default , model_config )
1156
+ continue
1157
+ if is_model_class (type (model_default )) or is_pydantic_dataclass (type (model_default )):
1158
+ positional_args += cls ._serialized_args (
1159
+ model_default , model_config , prefix = f'{ prefix } { preferred_alias } .'
1160
+ )
1161
+ continue
1162
+
1163
+ arg_name = f'{ prefix } { cls ._check_kebab_name (cli_settings , preferred_alias )} '
1164
+ value : str | list [Any ] | dict [str , Any ] = (
1165
+ json .dumps (model_default ) if isinstance (model_default , (dict , list , set )) else str (model_default )
1166
+ )
1167
+
1168
+ if is_alias_path_only :
1169
+ # For alias path only, we wont know the complete value until we've finished parsing the entire class. In
1170
+ # this case, insert value as a non-string reference pointing to the relevant alias_path_only_defaults
1171
+ # entry and convert into completed string value later.
1172
+ arg_name , value = cls ._update_alias_path_only_default (
1173
+ arg_name , value , field_info , alias_path_only_defaults
1174
+ )
1175
+
1176
+ if arg_name in added_args :
1177
+ continue
1178
+ added_args .add (arg_name )
1179
+
1180
+ if _CliPositionalArg in field_info .metadata :
1181
+ if is_alias_path_only :
1182
+ positional_args .append (value )
1183
+ continue
1184
+ for value in model_default if isinstance (model_default , list ) else [model_default ]:
1185
+ value = json .dumps (value ) if isinstance (value , (dict , list , set )) else str (value )
1186
+ positional_args .append (value )
1187
+ continue
1146
1188
1147
- def _serialized_args (self ) -> list [str ]:
1148
- if not self ._is_serialize_args :
1149
- raise SettingsError ('Root parser is not _CliInternalArgSerializer' )
1189
+ flag_chars = f'{ cli_settings .cli_flag_prefix_char * min (len (arg_name ), 2 )} '
1190
+ kwargs = {'metavar' : cls ._metavar_format (cli_settings , field_info .annotation )}
1191
+ cls ._convert_bool_flag (cli_settings , kwargs , field_info , model_default )
1192
+ # Note: cls._convert_bool_flag will add action to kwargs if value is implicit bool flag
1193
+ if 'action' in kwargs and model_default is False :
1194
+ flag_chars += 'no-'
1150
1195
1151
- cli_args = []
1152
- for arg , values in self ._serialize_positional_args .items ():
1153
- for value in values if isinstance (values , list ) else [values ]:
1154
- value = json .dumps (value ) if isinstance (value , (dict , list , set )) else str (value )
1155
- cli_args .append (value )
1196
+ optional_args .append (f'{ flag_chars } { arg_name } ' )
1156
1197
1157
- for arg , value in self .env_vars .items ():
1158
- if arg not in self ._serialize_positional_args :
1159
- value = json .dumps (value ) if isinstance (value , (dict , list , set )) else str (value )
1160
- cli_args .append (f'{ self .cli_flag_prefix_char * min (len (arg ), 2 )} { arg } ' )
1161
- cli_args .append (value )
1198
+ # If implicit bool flag, do not add a value
1199
+ if 'action' not in kwargs :
1200
+ optional_args .append (value )
1162
1201
1163
- return cli_args
1202
+ serialized_args : list [str ] = []
1203
+ serialized_args += [json .dumps (value ) if not isinstance (value , str ) else value for value in optional_args ]
1204
+ serialized_args += [json .dumps (value ) if not isinstance (value , str ) else value for value in positional_args ]
1205
+ return serialized_args + subcommand_args
0 commit comments