Skip to content

Commit be11824

Browse files
authored
CLI Serialization Fixes (#649)
1 parent 3a2b7b4 commit be11824

File tree

3 files changed

+186
-93
lines changed

3 files changed

+186
-93
lines changed

pydantic_settings/main.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
from types import SimpleNamespace
99
from typing import Any, ClassVar, TypeVar
1010

11-
from pydantic import ConfigDict, create_model
11+
from pydantic import ConfigDict
1212
from pydantic._internal._config import config_keys
1313
from pydantic._internal._signature import _field_name_for_signature
1414
from pydantic._internal._utils import deep_update, is_model_class
1515
from pydantic.dataclasses import is_pydantic_dataclass
16-
from pydantic.fields import FieldInfo
1716
from pydantic.main import BaseModel
1817

1918
from .exceptions import SettingsError
@@ -31,7 +30,6 @@
3130
SecretsSettingsSource,
3231
get_subcommand,
3332
)
34-
from .sources.providers.cli import _CliInternalArgSerializer
3533

3634
T = TypeVar('T')
3735

@@ -416,8 +414,8 @@ def _settings_build_values(
416414
)
417415
sources = (cli_settings,) + sources
418416
# We ensure that if command line arguments haven't been parsed yet, we do so.
419-
elif cli_parse_args and not custom_cli_sources[0].env_vars:
420-
custom_cli_sources[0](args=cli_parse_args)
417+
elif cli_parse_args not in (None, False) and not custom_cli_sources[0].env_vars:
418+
custom_cli_sources[0](args=cli_parse_args) # type: ignore
421419

422420
if sources:
423421
state: dict[str, Any] = {}
@@ -647,14 +645,4 @@ def serialize(model: PydanticModel) -> list[str]:
647645
"""
648646

649647
base_settings_cls = CliApp._get_base_settings_cls(type(model))
650-
model_field_definitions: dict[str, Any] = {}
651-
for field_name, field_info in base_settings_cls.model_fields.items():
652-
model_field_definitions[field_name] = (
653-
field_info.annotation,
654-
FieldInfo.merge_field_infos(field_info, default=getattr(model, field_name)),
655-
)
656-
657-
cli_serialize_cls = create_model('CliSerialize', __base__=base_settings_cls, **model_field_definitions)
658-
return CliSettingsSource[Any](
659-
cli_serialize_cls, cli_parse_args=[], root_parser=_CliInternalArgSerializer()
660-
)._serialized_args()
648+
return CliSettingsSource._serialized_args(model, base_settings_cls.model_config)

pydantic_settings/sources/providers/cli.py

Lines changed: 105 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636

3737
import typing_extensions
38-
from pydantic import AliasChoices, AliasPath, BaseModel, Field
38+
from pydantic import AliasChoices, AliasPath, BaseModel, Field, create_model
3939
from pydantic._internal._repr import Representation
4040
from pydantic._internal._utils import is_model_class
4141
from pydantic.dataclasses import is_pydantic_dataclass
@@ -47,7 +47,15 @@
4747

4848
from ...exceptions import SettingsError
4949
from ...utils import _lenient_issubclass, _WithArgsTypes
50-
from ..types import NoDecode, _CliExplicitFlag, _CliImplicitFlag, _CliPositionalArg, _CliSubCommand, _CliUnknownArgs
50+
from ..types import (
51+
NoDecode,
52+
PydanticModel,
53+
_CliExplicitFlag,
54+
_CliImplicitFlag,
55+
_CliPositionalArg,
56+
_CliSubCommand,
57+
_CliUnknownArgs,
58+
)
5159
from ..utils import (
5260
_annotation_contains_types,
5361
_annotation_enum_val_to_name,
@@ -74,10 +82,6 @@ def error(self, message: str) -> NoReturn:
7482
super().error(message)
7583

7684

77-
class _CliInternalArgSerializer(_CliInternalArgParser):
78-
pass
79-
80-
8185
class CliMutuallyExclusiveGroup(BaseModel):
8286
pass
8387

@@ -666,8 +670,6 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
666670
self._formatter_class = formatter_class
667671
self._cli_dict_args: dict[str, type[Any] | None] = {}
668672
self._cli_subcommands: defaultdict[str, dict[str, str]] = defaultdict(dict)
669-
self._is_serialize_args = isinstance(root_parser, _CliInternalArgSerializer)
670-
self._serialize_positional_args: dict[str, Any] = {}
671673
self._add_parser_args(
672674
parser=self.root_parser,
673675
model=self.settings_cls,
@@ -693,7 +695,6 @@ def _add_parser_args(
693695
) -> ArgumentParser:
694696
subparsers: Any = None
695697
alias_path_args: dict[str, str] = {}
696-
alias_path_only_defaults: dict[str, Any] = {}
697698
# Ignore model default if the default is a model and not a subclass of the current model.
698699
model_default = (
699700
None
@@ -762,11 +763,9 @@ def _add_parser_args(
762763
is_append_action = _annotation_contains_types(
763764
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
764765
)
765-
is_parser_submodel = bool(sub_models) and not is_append_action
766+
is_parser_submodel = sub_models and not is_append_action
766767
kwargs: dict[str, Any] = {}
767-
kwargs['default'] = self._get_cli_default_value(
768-
field_name, field_info, model_default, is_parser_submodel
769-
)
768+
kwargs['default'] = CLI_SUPPRESS
770769
kwargs['help'] = self._help_format(field_name, field_info, model_default, is_model_suppressed)
771770
kwargs['metavar'] = self._metavar_format(field_info.annotation)
772771
kwargs['required'] = (
@@ -825,14 +824,8 @@ def _add_parser_args(
825824
self._add_argument(
826825
parser, *(f'{flag_prefix[: len(name)]}{name}' for name in arg_names), **kwargs
827826
)
828-
elif kwargs['default'] != CLI_SUPPRESS:
829-
self._update_alias_path_only_defaults(
830-
kwargs['dest'], kwargs['default'], field_info, alias_path_only_defaults
831-
)
832827

833-
self._add_parser_alias_paths(
834-
parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group, alias_path_only_defaults
835-
)
828+
self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
836829
return parser
837830

838831
def _check_kebab_name(self, name: str) -> str:
@@ -859,6 +852,8 @@ def _convert_positional_arg(
859852
) -> tuple[list[str], str]:
860853
flag_prefix = ''
861854
arg_names = [kwargs['dest']]
855+
kwargs['default'] = PydanticUndefined
856+
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())
862857

863858
# Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in
864859
# conjunction with model_default instead of the derived kwargs['required'].
@@ -869,13 +864,6 @@ def _convert_positional_arg(
869864
elif not is_required:
870865
kwargs['nargs'] = '?'
871866

872-
if self._is_serialize_args:
873-
self._serialize_positional_args[kwargs['dest']] = kwargs['default']
874-
kwargs['nargs'] = '*'
875-
876-
kwargs['default'] = PydanticUndefined
877-
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())
878-
879867
del kwargs['dest']
880868
del kwargs['required']
881869
return arg_names, flag_prefix
@@ -963,7 +951,7 @@ def _add_parser_submodels(
963951
is_model_suppressed = self._is_field_suppressed(field_info) or is_model_suppressed
964952
if is_model_suppressed:
965953
model_group_kwargs['description'] = CLI_SUPPRESS
966-
if not self.cli_avoid_json and not self._is_serialize_args:
954+
if not self.cli_avoid_json:
967955
added_args.append(arg_names[0])
968956
kwargs['nargs'] = '?'
969957
kwargs['const'] = '{}'
@@ -993,7 +981,6 @@ def _add_parser_alias_paths(
993981
arg_prefix: str,
994982
subcommand_prefix: str,
995983
group: Any,
996-
alias_path_only_defaults: dict[str, Any],
997984
) -> None:
998985
if alias_path_args:
999986
context = parser
@@ -1009,9 +996,9 @@ def _add_parser_alias_paths(
1009996
else f'{arg_prefix.replace(subcommand_prefix, "", 1)}{name}'
1010997
)
1011998
kwargs: dict[str, Any] = {}
999+
kwargs['default'] = CLI_SUPPRESS
10121000
kwargs['help'] = 'pydantic alias path'
10131001
kwargs['dest'] = f'{arg_prefix}{name}'
1014-
kwargs['default'] = alias_path_only_defaults.get(kwargs['dest'], CLI_SUPPRESS)
10151002
if metavar == 'dict' or is_nested_alias_path:
10161003
kwargs['metavar'] = 'dict'
10171004
else:
@@ -1105,34 +1092,27 @@ def _is_field_suppressed(self, field_info: FieldInfo) -> bool:
11051092
_help = field_info.description if field_info.description else ''
11061093
return _help == CLI_SUPPRESS or CLI_SUPPRESS in field_info.metadata
11071094

1108-
def _get_cli_default_value(
1109-
self, field_name: str, field_info: FieldInfo, model_default: Any, is_parser_submodel: bool
1110-
) -> Any:
1111-
if is_parser_submodel or not isinstance(self.root_parser, _CliInternalArgSerializer):
1112-
return CLI_SUPPRESS
1113-
1114-
return getattr(model_default, field_name, field_info.default)
1115-
1116-
def _update_alias_path_only_defaults(
1117-
self, dest: str, default: Any, field_info: FieldInfo, alias_path_only_defaults: dict[str, Any]
1118-
) -> None:
1095+
@classmethod
1096+
def _update_alias_path_only_default(
1097+
cls, arg_name: str, value: Any, field_info: FieldInfo, alias_path_only_defaults: dict[str, Any]
1098+
) -> tuple[str, list[Any] | dict[str, Any]]:
11191099
alias_path: AliasPath = [
11201100
alias if isinstance(alias, AliasPath) else cast(AliasPath, alias.choices[0])
11211101
for alias in (field_info.alias, field_info.validation_alias)
11221102
if isinstance(alias, (AliasPath, AliasChoices))
11231103
][0]
11241104

11251105
alias_nested_paths: list[str] = alias_path.path[1:-1] # type: ignore
1126-
if '.' in dest:
1127-
alias_nested_paths = dest.split('.') + alias_nested_paths
1128-
dest = alias_nested_paths.pop(0)
1106+
if '.' in arg_name:
1107+
alias_nested_paths = arg_name.split('.') + alias_nested_paths
1108+
arg_name = alias_nested_paths.pop(0)
11291109

11301110
if not alias_nested_paths:
1131-
alias_path_only_defaults.setdefault(dest, [])
1132-
alias_default = alias_path_only_defaults[dest]
1111+
alias_path_only_defaults.setdefault(arg_name, [])
1112+
alias_default = alias_path_only_defaults[arg_name]
11331113
else:
1134-
alias_path_only_defaults.setdefault(dest, {})
1135-
current_path = alias_path_only_defaults[dest]
1114+
alias_path_only_defaults.setdefault(arg_name, {})
1115+
current_path = alias_path_only_defaults[arg_name]
11361116

11371117
for nested_path in alias_nested_paths[:-1]:
11381118
current_path.setdefault(nested_path, {})
@@ -1142,22 +1122,84 @@ def _update_alias_path_only_defaults(
11421122

11431123
alias_path_index = cast(int, alias_path.path[-1])
11441124
alias_default.extend([''] * max(alias_path_index + 1 - len(alias_default), 0))
1145-
alias_default[alias_path_index] = default
1125+
alias_default[alias_path_index] = value
1126+
return arg_name, alias_path_only_defaults[arg_name]
1127+
1128+
@classmethod
1129+
def _serialized_args(cls, model: PydanticModel, model_config: Any, prefix: str = '') -> list[str]:
1130+
model_field_definitions: dict[str, Any] = {}
1131+
for field_name, field_info in _get_model_fields(type(model)).items():
1132+
model_default = getattr(model, field_name)
1133+
if field_info.default == model_default:
1134+
continue
1135+
if _CliSubCommand in field_info.metadata and model_default is None:
1136+
continue
1137+
model_field_definitions[field_name] = (field_info.annotation, field_info)
1138+
cli_serialize_cls = create_model('CliSerialize', __config__=model_config, **model_field_definitions)
1139+
1140+
added_args: set[str] = set()
1141+
alias_path_args: dict[str, str] = {}
1142+
alias_path_only_defaults: dict[str, Any] = {}
1143+
optional_args: list[str | list[Any] | dict[str, Any]] = []
1144+
positional_args: list[str | list[Any] | dict[str, Any]] = []
1145+
subcommand_args: list[str] = []
1146+
cli_settings = CliSettingsSource[Any](cli_serialize_cls)
1147+
for field_name, field_info in _get_model_fields(cli_serialize_cls).items():
1148+
model_default = getattr(model, field_name)
1149+
alias_names, is_alias_path_only = _get_alias_names(
1150+
field_name, field_info, alias_path_args=alias_path_args, case_sensitive=cli_settings.case_sensitive
1151+
)
1152+
preferred_alias = alias_names[0]
1153+
if _CliSubCommand in field_info.metadata:
1154+
subcommand_args.append(cls._check_kebab_name(cli_settings, preferred_alias))
1155+
subcommand_args += cls._serialized_args(model_default, model_config)
1156+
continue
1157+
if is_model_class(type(model_default)) or is_pydantic_dataclass(type(model_default)):
1158+
positional_args += cls._serialized_args(
1159+
model_default, model_config, prefix=f'{prefix}{preferred_alias}.'
1160+
)
1161+
continue
1162+
1163+
arg_name = f'{prefix}{cls._check_kebab_name(cli_settings, preferred_alias)}'
1164+
value: str | list[Any] | dict[str, Any] = (
1165+
json.dumps(model_default) if isinstance(model_default, (dict, list, set)) else str(model_default)
1166+
)
1167+
1168+
if is_alias_path_only:
1169+
# For alias path only, we won't know the complete value until we've finished parsing the entire class. In
1170+
# this case, insert value as a non-string reference pointing to the relevant alias_path_only_defaults
1171+
# entry and convert into completed string value later.
1172+
arg_name, value = cls._update_alias_path_only_default(
1173+
arg_name, value, field_info, alias_path_only_defaults
1174+
)
1175+
1176+
if arg_name in added_args:
1177+
continue
1178+
added_args.add(arg_name)
1179+
1180+
if _CliPositionalArg in field_info.metadata:
1181+
if is_alias_path_only:
1182+
positional_args.append(value)
1183+
continue
1184+
for value in model_default if isinstance(model_default, list) else [model_default]:
1185+
value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value)
1186+
positional_args.append(value)
1187+
continue
11461188

1147-
def _serialized_args(self) -> list[str]:
1148-
if not self._is_serialize_args:
1149-
raise SettingsError('Root parser is not _CliInternalArgSerializer')
1189+
flag_chars = f'{cli_settings.cli_flag_prefix_char * min(len(arg_name), 2)}'
1190+
kwargs = {'metavar': cls._metavar_format(cli_settings, field_info.annotation)}
1191+
cls._convert_bool_flag(cli_settings, kwargs, field_info, model_default)
1192+
# Note: cls._convert_bool_flag will add action to kwargs if value is implicit bool flag
1193+
if 'action' in kwargs and model_default is False:
1194+
flag_chars += 'no-'
11501195

1151-
cli_args = []
1152-
for arg, values in self._serialize_positional_args.items():
1153-
for value in values if isinstance(values, list) else [values]:
1154-
value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value)
1155-
cli_args.append(value)
1196+
optional_args.append(f'{flag_chars}{arg_name}')
11561197

1157-
for arg, value in self.env_vars.items():
1158-
if arg not in self._serialize_positional_args:
1159-
value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value)
1160-
cli_args.append(f'{self.cli_flag_prefix_char * min(len(arg), 2)}{arg}')
1161-
cli_args.append(value)
1198+
# If implicit bool flag, do not add a value
1199+
if 'action' not in kwargs:
1200+
optional_args.append(value)
11621201

1163-
return cli_args
1202+
serialized_args: list[str] = []
1203+
serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in optional_args]
1204+
serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in positional_args]
1205+
return serialized_args + subcommand_args

0 commit comments

Comments
 (0)