diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index 57929bbc..4df690ce 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -8,12 +8,11 @@ from types import SimpleNamespace from typing import Any, ClassVar, TypeVar -from pydantic import ConfigDict, create_model +from pydantic import ConfigDict from pydantic._internal._config import config_keys from pydantic._internal._signature import _field_name_for_signature from pydantic._internal._utils import deep_update, is_model_class from pydantic.dataclasses import is_pydantic_dataclass -from pydantic.fields import FieldInfo from pydantic.main import BaseModel from .exceptions import SettingsError @@ -31,7 +30,6 @@ SecretsSettingsSource, get_subcommand, ) -from .sources.providers.cli import _CliInternalArgSerializer T = TypeVar('T') @@ -416,8 +414,8 @@ def _settings_build_values( ) sources = (cli_settings,) + sources # We ensure that if command line arguments haven't been parsed yet, we do so. - elif cli_parse_args and not custom_cli_sources[0].env_vars: - custom_cli_sources[0](args=cli_parse_args) + elif cli_parse_args not in (None, False) and not custom_cli_sources[0].env_vars: + custom_cli_sources[0](args=cli_parse_args) # type: ignore if sources: state: dict[str, Any] = {} @@ -647,14 +645,4 @@ def serialize(model: PydanticModel) -> list[str]: """ base_settings_cls = CliApp._get_base_settings_cls(type(model)) - model_field_definitions: dict[str, Any] = {} - for field_name, field_info in base_settings_cls.model_fields.items(): - model_field_definitions[field_name] = ( - field_info.annotation, - FieldInfo.merge_field_infos(field_info, default=getattr(model, field_name)), - ) - - cli_serialize_cls = create_model('CliSerialize', __base__=base_settings_cls, **model_field_definitions) - return CliSettingsSource[Any]( - cli_serialize_cls, cli_parse_args=[], root_parser=_CliInternalArgSerializer() - )._serialized_args() + return CliSettingsSource._serialized_args(model, base_settings_cls.model_config) diff --git a/pydantic_settings/sources/providers/cli.py b/pydantic_settings/sources/providers/cli.py index fefa103c..93e0ba42 100644 --- a/pydantic_settings/sources/providers/cli.py +++ b/pydantic_settings/sources/providers/cli.py @@ -35,7 +35,7 @@ ) import typing_extensions -from pydantic import AliasChoices, AliasPath, BaseModel, Field +from pydantic import AliasChoices, AliasPath, BaseModel, Field, create_model from pydantic._internal._repr import Representation from pydantic._internal._utils import is_model_class from pydantic.dataclasses import is_pydantic_dataclass @@ -47,7 +47,15 @@ from ...exceptions import SettingsError from ...utils import _lenient_issubclass, _WithArgsTypes -from ..types import NoDecode, _CliExplicitFlag, _CliImplicitFlag, _CliPositionalArg, _CliSubCommand, _CliUnknownArgs +from ..types import ( + NoDecode, + PydanticModel, + _CliExplicitFlag, + _CliImplicitFlag, + _CliPositionalArg, + _CliSubCommand, + _CliUnknownArgs, +) from ..utils import ( _annotation_contains_types, _annotation_enum_val_to_name, @@ -74,10 +82,6 @@ def error(self, message: str) -> NoReturn: super().error(message) -class _CliInternalArgSerializer(_CliInternalArgParser): - pass - - class CliMutuallyExclusiveGroup(BaseModel): pass @@ -666,8 +670,6 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace: self._formatter_class = formatter_class self._cli_dict_args: dict[str, type[Any] | None] = {} self._cli_subcommands: defaultdict[str, dict[str, str]] = defaultdict(dict) - self._is_serialize_args = isinstance(root_parser, _CliInternalArgSerializer) - self._serialize_positional_args: dict[str, Any] = {} self._add_parser_args( parser=self.root_parser, model=self.settings_cls, @@ -693,7 +695,6 @@ def _add_parser_args( ) -> ArgumentParser: subparsers: Any = None alias_path_args: dict[str, str] = {} - alias_path_only_defaults: dict[str, Any] = {} # Ignore model default if the default is a model and not a subclass of the current model. model_default = ( None @@ -762,11 +763,9 @@ def _add_parser_args( is_append_action = _annotation_contains_types( field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True ) - is_parser_submodel = bool(sub_models) and not is_append_action + is_parser_submodel = sub_models and not is_append_action kwargs: dict[str, Any] = {} - kwargs['default'] = self._get_cli_default_value( - field_name, field_info, model_default, is_parser_submodel - ) + kwargs['default'] = CLI_SUPPRESS kwargs['help'] = self._help_format(field_name, field_info, model_default, is_model_suppressed) kwargs['metavar'] = self._metavar_format(field_info.annotation) kwargs['required'] = ( @@ -825,14 +824,8 @@ def _add_parser_args( self._add_argument( parser, *(f'{flag_prefix[: len(name)]}{name}' for name in arg_names), **kwargs ) - elif kwargs['default'] != CLI_SUPPRESS: - self._update_alias_path_only_defaults( - kwargs['dest'], kwargs['default'], field_info, alias_path_only_defaults - ) - self._add_parser_alias_paths( - parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group, alias_path_only_defaults - ) + self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group) return parser def _check_kebab_name(self, name: str) -> str: @@ -859,6 +852,8 @@ def _convert_positional_arg( ) -> tuple[list[str], str]: flag_prefix = '' arg_names = [kwargs['dest']] + kwargs['default'] = PydanticUndefined + kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper()) # Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in # conjunction with model_default instead of the derived kwargs['required']. @@ -869,13 +864,6 @@ def _convert_positional_arg( elif not is_required: kwargs['nargs'] = '?' - if self._is_serialize_args: - self._serialize_positional_args[kwargs['dest']] = kwargs['default'] - kwargs['nargs'] = '*' - - kwargs['default'] = PydanticUndefined - kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper()) - del kwargs['dest'] del kwargs['required'] return arg_names, flag_prefix @@ -963,7 +951,7 @@ def _add_parser_submodels( is_model_suppressed = self._is_field_suppressed(field_info) or is_model_suppressed if is_model_suppressed: model_group_kwargs['description'] = CLI_SUPPRESS - if not self.cli_avoid_json and not self._is_serialize_args: + if not self.cli_avoid_json: added_args.append(arg_names[0]) kwargs['nargs'] = '?' kwargs['const'] = '{}' @@ -993,7 +981,6 @@ def _add_parser_alias_paths( arg_prefix: str, subcommand_prefix: str, group: Any, - alias_path_only_defaults: dict[str, Any], ) -> None: if alias_path_args: context = parser @@ -1009,9 +996,9 @@ def _add_parser_alias_paths( else f'{arg_prefix.replace(subcommand_prefix, "", 1)}{name}' ) kwargs: dict[str, Any] = {} + kwargs['default'] = CLI_SUPPRESS kwargs['help'] = 'pydantic alias path' kwargs['dest'] = f'{arg_prefix}{name}' - kwargs['default'] = alias_path_only_defaults.get(kwargs['dest'], CLI_SUPPRESS) if metavar == 'dict' or is_nested_alias_path: kwargs['metavar'] = 'dict' else: @@ -1105,17 +1092,10 @@ def _is_field_suppressed(self, field_info: FieldInfo) -> bool: _help = field_info.description if field_info.description else '' return _help == CLI_SUPPRESS or CLI_SUPPRESS in field_info.metadata - def _get_cli_default_value( - self, field_name: str, field_info: FieldInfo, model_default: Any, is_parser_submodel: bool - ) -> Any: - if is_parser_submodel or not isinstance(self.root_parser, _CliInternalArgSerializer): - return CLI_SUPPRESS - - return getattr(model_default, field_name, field_info.default) - - def _update_alias_path_only_defaults( - self, dest: str, default: Any, field_info: FieldInfo, alias_path_only_defaults: dict[str, Any] - ) -> None: + @classmethod + def _update_alias_path_only_default( + cls, arg_name: str, value: Any, field_info: FieldInfo, alias_path_only_defaults: dict[str, Any] + ) -> tuple[str, list[Any] | dict[str, Any]]: alias_path: AliasPath = [ alias if isinstance(alias, AliasPath) else cast(AliasPath, alias.choices[0]) for alias in (field_info.alias, field_info.validation_alias) @@ -1123,16 +1103,16 @@ def _update_alias_path_only_defaults( ][0] alias_nested_paths: list[str] = alias_path.path[1:-1] # type: ignore - if '.' in dest: - alias_nested_paths = dest.split('.') + alias_nested_paths - dest = alias_nested_paths.pop(0) + if '.' in arg_name: + alias_nested_paths = arg_name.split('.') + alias_nested_paths + arg_name = alias_nested_paths.pop(0) if not alias_nested_paths: - alias_path_only_defaults.setdefault(dest, []) - alias_default = alias_path_only_defaults[dest] + alias_path_only_defaults.setdefault(arg_name, []) + alias_default = alias_path_only_defaults[arg_name] else: - alias_path_only_defaults.setdefault(dest, {}) - current_path = alias_path_only_defaults[dest] + alias_path_only_defaults.setdefault(arg_name, {}) + current_path = alias_path_only_defaults[arg_name] for nested_path in alias_nested_paths[:-1]: current_path.setdefault(nested_path, {}) @@ -1142,22 +1122,84 @@ def _update_alias_path_only_defaults( alias_path_index = cast(int, alias_path.path[-1]) alias_default.extend([''] * max(alias_path_index + 1 - len(alias_default), 0)) - alias_default[alias_path_index] = default + alias_default[alias_path_index] = value + return arg_name, alias_path_only_defaults[arg_name] + + @classmethod + def _serialized_args(cls, model: PydanticModel, model_config: Any, prefix: str = '') -> list[str]: + model_field_definitions: dict[str, Any] = {} + for field_name, field_info in _get_model_fields(type(model)).items(): + model_default = getattr(model, field_name) + if field_info.default == model_default: + continue + if _CliSubCommand in field_info.metadata and model_default is None: + continue + model_field_definitions[field_name] = (field_info.annotation, field_info) + cli_serialize_cls = create_model('CliSerialize', __config__=model_config, **model_field_definitions) + + added_args: set[str] = set() + alias_path_args: dict[str, str] = {} + alias_path_only_defaults: dict[str, Any] = {} + optional_args: list[str | list[Any] | dict[str, Any]] = [] + positional_args: list[str | list[Any] | dict[str, Any]] = [] + subcommand_args: list[str] = [] + cli_settings = CliSettingsSource[Any](cli_serialize_cls) + for field_name, field_info in _get_model_fields(cli_serialize_cls).items(): + model_default = getattr(model, field_name) + alias_names, is_alias_path_only = _get_alias_names( + field_name, field_info, alias_path_args=alias_path_args, case_sensitive=cli_settings.case_sensitive + ) + preferred_alias = alias_names[0] + if _CliSubCommand in field_info.metadata: + subcommand_args.append(cls._check_kebab_name(cli_settings, preferred_alias)) + subcommand_args += cls._serialized_args(model_default, model_config) + continue + if is_model_class(type(model_default)) or is_pydantic_dataclass(type(model_default)): + positional_args += cls._serialized_args( + model_default, model_config, prefix=f'{prefix}{preferred_alias}.' + ) + continue + + arg_name = f'{prefix}{cls._check_kebab_name(cli_settings, preferred_alias)}' + value: str | list[Any] | dict[str, Any] = ( + json.dumps(model_default) if isinstance(model_default, (dict, list, set)) else str(model_default) + ) + + if is_alias_path_only: + # For alias path only, we wont know the complete value until we've finished parsing the entire class. In + # this case, insert value as a non-string reference pointing to the relevant alias_path_only_defaults + # entry and convert into completed string value later. + arg_name, value = cls._update_alias_path_only_default( + arg_name, value, field_info, alias_path_only_defaults + ) + + if arg_name in added_args: + continue + added_args.add(arg_name) + + if _CliPositionalArg in field_info.metadata: + if is_alias_path_only: + positional_args.append(value) + continue + for value in model_default if isinstance(model_default, list) else [model_default]: + value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value) + positional_args.append(value) + continue - def _serialized_args(self) -> list[str]: - if not self._is_serialize_args: - raise SettingsError('Root parser is not _CliInternalArgSerializer') + flag_chars = f'{cli_settings.cli_flag_prefix_char * min(len(arg_name), 2)}' + kwargs = {'metavar': cls._metavar_format(cli_settings, field_info.annotation)} + cls._convert_bool_flag(cli_settings, kwargs, field_info, model_default) + # Note: cls._convert_bool_flag will add action to kwargs if value is implicit bool flag + if 'action' in kwargs and model_default is False: + flag_chars += 'no-' - cli_args = [] - for arg, values in self._serialize_positional_args.items(): - for value in values if isinstance(values, list) else [values]: - value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value) - cli_args.append(value) + optional_args.append(f'{flag_chars}{arg_name}') - for arg, value in self.env_vars.items(): - if arg not in self._serialize_positional_args: - value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value) - cli_args.append(f'{self.cli_flag_prefix_char * min(len(arg), 2)}{arg}') - cli_args.append(value) + # If implicit bool flag, do not add a value + if 'action' not in kwargs: + optional_args.append(value) - return cli_args + serialized_args: list[str] = [] + serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in optional_args] + serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in positional_args] + return serialized_args + subcommand_args diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index 68d3693f..5f8cc37a 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -301,14 +301,14 @@ class Cfg(BaseSettings, cli_avoid_json=avoid_json): assert serialized_cli_args == [ '-a', 'a', + '--path1', + '["", "b1"]', '-b', 'b', - '--str', - 'str', - '--path1', - '["","b1"]', '--path2', '{"deep": ["", "b2"]}', + '--str', + 'str', ] assert CliApp.run(Cfg, cli_args=serialized_cli_args).model_dump() == cfg.model_dump() @@ -352,12 +352,12 @@ class Cfg(BaseSettings, cli_avoid_json=avoid_json): assert serialized_cli_args == [ '--nest.a', 'a', + '--nest', + '{"path1": ["", "b1"], "path2": {"deep": ["", "b2"]}}', '--nest.b', 'b', '--nest.str', 'str', - '--nest', - '{"path1": ["", "b1"], "path2": {"deep": ["", "b2"]}}', ] assert CliApp.run(Cfg, cli_args=serialized_cli_args).model_dump() == cfg.model_dump() @@ -1539,8 +1539,19 @@ class ImplicitSettings(BaseSettings, cli_implicit_flags=True, cli_enforce_requir 'implicit_opt': False, } - assert CliApp.run(ExplicitSettings, cli_args=['--explicit_req=True', '--implicit_req']).model_dump() == expected - assert CliApp.run(ImplicitSettings, cli_args=['--explicit_req=True', '--implicit_req']).model_dump() == expected + explicit_settings = CliApp.run(ExplicitSettings, cli_args=['--explicit_req=True', '--implicit_req']) + assert explicit_settings.model_dump() == expected + + implicit_settings = CliApp.run(ImplicitSettings, cli_args=['--explicit_req=True', '--implicit_req']) + assert implicit_settings.model_dump() == expected + + serialized_args = CliApp.serialize(explicit_settings) + assert serialized_args == ['--explicit_req', 'True', '--implicit_req'] + assert CliApp.run(ExplicitSettings, cli_args=serialized_args).model_dump() == expected + + serialized_args = CliApp.serialize(implicit_settings) + assert serialized_args == ['--explicit_req', 'True', '--implicit_req'] + assert CliApp.run(ImplicitSettings, cli_args=serialized_args).model_dump() == expected def test_cli_avoid_json(capsys, monkeypatch): @@ -1974,9 +1985,6 @@ class Cfg(BaseSettings): with pytest.raises(SettingsError, match='CLI settings source prefix is invalid: 123'): CliSettingsSource(Cfg, cli_prefix='123') - with pytest.raises(SettingsError, match='Root parser is not _CliInternalArgSerializer'): - CliSettingsSource[Any](Cfg, cli_parse_args=[])._serialized_args() - class Food(BaseModel): fruit: FruitsEnum = FruitsEnum.kiwi @@ -2432,7 +2440,7 @@ class Root(BaseModel): root_subcmd: CliSubCommand[SubModel] root_arg: str - assert CliApp.run( + root = CliApp.run( Root, cli_args=[ '--root-arg=hi', @@ -2442,7 +2450,8 @@ class Root(BaseModel): 'hey', '--deep-arg=bye', ], - ).model_dump() == { + ) + assert root.model_dump() == { 'root_arg': 'hi', 'root_subcmd': { 'sub_arg': 'hello', @@ -2450,6 +2459,19 @@ class Root(BaseModel): }, } + serialized_cli_args = CliApp.serialize(root) + assert serialized_cli_args == [ + '--root-arg', + 'hi', + 'root-subcmd', + '--sub-arg', + 'hello', + 'sub-subcmd', + '--deep-arg', + 'bye', + 'hey', + ] + with monkeypatch.context() as m: m.setattr(sys, 'argv', ['example.py', '--help']) with pytest.raises(SystemExit): @@ -2626,7 +2648,7 @@ class Settings(BaseSettings): } -def test_cli_serialize_positional_args(env): +def test_cli_serialize_positional_args(): class Nested(BaseModel): deep: CliPositionalArg[int] @@ -2680,3 +2702,44 @@ class Cfg(BaseSettings): assert parsed_args.extra == 4 # With parsed arguments passed to CliApp.run, the parser should not need to be called again. assert CliApp.run(Cfg, cli_args=parsed_args, cli_settings_source=cli_settings).model_dump() == {'pet': 'dog'} + + +def test_cli_serialize_non_default_values(): + class Cfg(BaseSettings): + default_val: int = 123 + non_default_val: int + + cfg = Cfg(non_default_val=456) + assert cfg.model_dump() == {'default_val': 123, 'non_default_val': 456} + + serialized_cli_args = CliApp.serialize(cfg) + assert serialized_cli_args == ['--non_default_val', '456'] + + assert CliApp.run(Cfg, cli_args=serialized_cli_args).model_dump() == cfg.model_dump() + + +def test_cli_serialize_ordering(): + class NestedCfg(BaseSettings): + positional: CliPositionalArg[str] + optional: int + + class Cfg(BaseSettings): + command: CliSubCommand[NestedCfg] + positional: CliPositionalArg[str] + optional: int + + cfg = Cfg(optional=0, positional='pos_1', command=NestedCfg(optional=2, positional='pos_3')) + assert cfg.model_dump() == {'command': {'optional': 2, 'positional': 'pos_3'}, 'optional': 0, 'positional': 'pos_1'} + + serialized_cli_args = CliApp.serialize(cfg) + assert serialized_cli_args == [ + '--optional', + '0', + 'pos_1', + 'command', + '--optional', + '2', + 'pos_3', + ] + + assert CliApp.run(Cfg, cli_args=serialized_cli_args).model_dump() == cfg.model_dump()