Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,29 @@ CliApp.run(Git, cli_args=['clone', 'repo', 'dir']).model_dump() == {

When executing a subcommand with an asynchronous cli_cmd, Pydantic settings automatically detects whether the current thread already has an active event loop. If so, the async command is run in a fresh thread to avoid conflicts. Otherwise, it uses asyncio.run() in the current thread. This handling ensures your asynchronous subcommands "just work" without additional manual setup.

### Serializing CLI Arguments

An instantiated Pydantic model can be serialized into its CLI arguments using the `CliApp.serialize` method.

```py
from pydantic import BaseModel

from pydantic_settings import CliApp


class Nested(BaseModel):
that: int


class Settings(BaseModel):
this: str
nested: Nested


print(CliApp.serialize(Settings(this='hello', nested=Nested(that=123))))
#> ['--this', 'hello', '--nested.that', '123']
```

### Mutually Exclusive Groups

CLI mutually exclusive groups can be created by inheriting from the `CliMutuallyExclusiveGroup` class.
Expand Down
66 changes: 50 additions & 16 deletions pydantic_settings/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from types import SimpleNamespace
from typing import Any, ClassVar, TypeVar

from pydantic import ConfigDict
from pydantic import ConfigDict, create_model
from pydantic._internal._config import config_keys
from pydantic._internal._signature import _field_name_for_signature
from pydantic._internal._utils import deep_update, is_model_class
from pydantic.dataclasses import is_pydantic_dataclass
from pydantic.fields import FieldInfo
from pydantic.main import BaseModel

from .exceptions import SettingsError
Expand All @@ -30,6 +31,7 @@
SecretsSettingsSource,
get_subcommand,
)
from .sources.providers.cli import _CliInternalArgSerializer

T = TypeVar('T')

Expand Down Expand Up @@ -477,6 +479,25 @@ class CliApp:
CLI applications.
"""

@staticmethod
def _get_base_settings_cls(model_cls: type[Any]) -> type[BaseSettings]:
if issubclass(model_cls, BaseSettings):
return model_cls

class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore
__doc__ = model_cls.__doc__
model_config = SettingsConfigDict(
nested_model_default_partial_update=True,
case_sensitive=True,
cli_hide_none_type=True,
cli_avoid_json=True,
cli_enforce_required=True,
cli_implicit_flags=True,
cli_kebab_case=True,
)

return CliAppBaseSettings

@staticmethod
def _run_cli_cmd(model: Any, cli_cmd_method_name: str, is_required: bool) -> Any:
command = getattr(type(model), cli_cmd_method_name, None)
Expand Down Expand Up @@ -575,22 +596,10 @@ def run(
model_init_data['_cli_exit_on_error'] = cli_exit_on_error
model_init_data['_cli_settings_source'] = cli_settings
if not issubclass(model_cls, BaseSettings):

class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore
__doc__ = model_cls.__doc__
model_config = SettingsConfigDict(
nested_model_default_partial_update=True,
case_sensitive=True,
cli_hide_none_type=True,
cli_avoid_json=True,
cli_enforce_required=True,
cli_implicit_flags=True,
cli_kebab_case=True,
)

model = CliAppBaseSettings(**model_init_data)
base_settings_cls = CliApp._get_base_settings_cls(model_cls)
model = base_settings_cls(**model_init_data)
model_init_data = {}
for field_name, field_info in type(model).model_fields.items():
for field_name, field_info in base_settings_cls.model_fields.items():
model_init_data[_field_name_for_signature(field_name, field_info)] = getattr(model, field_name)

return CliApp._run_cli_cmd(model_cls(**model_init_data), cli_cmd_method_name, is_required=False)
Expand Down Expand Up @@ -619,3 +628,28 @@ def run_subcommand(

subcommand = get_subcommand(model, is_required=True, cli_exit_on_error=cli_exit_on_error)
return CliApp._run_cli_cmd(subcommand, cli_cmd_method_name, is_required=True)

@staticmethod
def serialize(model: PydanticModel) -> list[str]:
"""
Serializes the CLI arguments for a Pydantic data model.

Args:
model: The data model to serialize.

Returns:
The serialized CLI arguments for the data model.
"""

base_settings_cls = CliApp._get_base_settings_cls(type(model))
model_field_definitions: dict[str, Any] = {}
for field_name, field_info in base_settings_cls.model_fields.items():
model_field_definitions[field_name] = (
field_info.annotation,
FieldInfo.merge_field_infos(field_info, default=getattr(model, field_name)),
)

cli_serialize_cls = create_model('CliSerialize', __base__=base_settings_cls, **model_field_definitions)
return CliSettingsSource[Any](
cli_serialize_cls, cli_parse_args=[], root_parser=_CliInternalArgSerializer()
)._serialized_args()
94 changes: 86 additions & 8 deletions pydantic_settings/sources/providers/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)

import typing_extensions
from pydantic import BaseModel, Field
from pydantic import AliasChoices, AliasPath, BaseModel, Field
from pydantic._internal._repr import Representation
from pydantic._internal._utils import is_model_class
from pydantic.dataclasses import is_pydantic_dataclass
Expand Down Expand Up @@ -74,6 +74,10 @@ def error(self, message: str) -> NoReturn:
super().error(message)


class _CliInternalArgSerializer(_CliInternalArgParser):
pass


class CliMutuallyExclusiveGroup(BaseModel):
pass

Expand Down Expand Up @@ -664,6 +668,8 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
self._formatter_class = formatter_class
self._cli_dict_args: dict[str, type[Any] | None] = {}
self._cli_subcommands: defaultdict[str, dict[str, str]] = defaultdict(dict)
self._is_serialize_args = isinstance(root_parser, _CliInternalArgSerializer)
self._serialize_positional_args: dict[str, Any] = {}
self._add_parser_args(
parser=self.root_parser,
model=self.settings_cls,
Expand All @@ -689,6 +695,7 @@ def _add_parser_args(
) -> ArgumentParser:
subparsers: Any = None
alias_path_args: dict[str, str] = {}
alias_path_only_defaults: dict[str, Any] = {}
# Ignore model default if the default is a model and not a subclass of the current model.
model_default = (
None
Expand Down Expand Up @@ -756,9 +763,11 @@ def _add_parser_args(
is_append_action = _annotation_contains_types(
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
)
is_parser_submodel = sub_models and not is_append_action
is_parser_submodel = bool(sub_models) and not is_append_action
kwargs: dict[str, Any] = {}
kwargs['default'] = CLI_SUPPRESS
kwargs['default'] = self._get_cli_default_value(
field_name, field_info, model_default, is_parser_submodel
)
kwargs['help'] = self._help_format(field_name, field_info, model_default, is_model_suppressed)
kwargs['metavar'] = self._metavar_format(field_info.annotation)
kwargs['required'] = (
Expand Down Expand Up @@ -817,8 +826,14 @@ def _add_parser_args(
self._add_argument(
parser, *(f'{flag_prefix[: len(name)]}{name}' for name in arg_names), **kwargs
)
elif kwargs['default'] != CLI_SUPPRESS:
self._update_alias_path_only_defaults(
kwargs['dest'], kwargs['default'], field_info, alias_path_only_defaults
)

self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
self._add_parser_alias_paths(
parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group, alias_path_only_defaults
)
return parser

def _check_kebab_name(self, name: str) -> str:
Expand All @@ -845,8 +860,6 @@ def _convert_positional_arg(
) -> tuple[list[str], str]:
flag_prefix = ''
arg_names = [kwargs['dest']]
kwargs['default'] = PydanticUndefined
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())

# Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in
# conjunction with model_default instead of the derived kwargs['required'].
Expand All @@ -857,6 +870,13 @@ def _convert_positional_arg(
elif not is_required:
kwargs['nargs'] = '?'

if self._is_serialize_args:
self._serialize_positional_args[kwargs['dest']] = kwargs['default']
kwargs['nargs'] = '*'

kwargs['default'] = PydanticUndefined
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())

del kwargs['dest']
del kwargs['required']
return arg_names, flag_prefix
Expand Down Expand Up @@ -944,7 +964,7 @@ def _add_parser_submodels(
is_model_suppressed = self._is_field_suppressed(field_info) or is_model_suppressed
if is_model_suppressed:
model_group_kwargs['description'] = CLI_SUPPRESS
if not self.cli_avoid_json:
if not self.cli_avoid_json and not self._is_serialize_args:
added_args.append(arg_names[0])
kwargs['nargs'] = '?'
kwargs['const'] = '{}'
Expand Down Expand Up @@ -974,6 +994,7 @@ def _add_parser_alias_paths(
arg_prefix: str,
subcommand_prefix: str,
group: Any,
alias_path_only_defaults: dict[str, Any],
) -> None:
if alias_path_args:
context = parser
Expand All @@ -989,9 +1010,9 @@ def _add_parser_alias_paths(
else f'{arg_prefix.replace(subcommand_prefix, "", 1)}{name}'
)
kwargs: dict[str, Any] = {}
kwargs['default'] = CLI_SUPPRESS
kwargs['help'] = 'pydantic alias path'
kwargs['dest'] = f'{arg_prefix}{name}'
kwargs['default'] = alias_path_only_defaults.get(kwargs['dest'], CLI_SUPPRESS)
if metavar == 'dict' or is_nested_alias_path:
kwargs['metavar'] = 'dict'
else:
Expand Down Expand Up @@ -1084,3 +1105,60 @@ def _help_format(
def _is_field_suppressed(self, field_info: FieldInfo) -> bool:
_help = field_info.description if field_info.description else ''
return _help == CLI_SUPPRESS or CLI_SUPPRESS in field_info.metadata

def _get_cli_default_value(
self, field_name: str, field_info: FieldInfo, model_default: Any, is_parser_submodel: bool
) -> Any:
if is_parser_submodel or not isinstance(self.root_parser, _CliInternalArgSerializer):
return CLI_SUPPRESS

return getattr(model_default, field_name, field_info.default)

def _update_alias_path_only_defaults(
self, dest: str, default: Any, field_info: FieldInfo, alias_path_only_defaults: dict[str, Any]
) -> None:
alias_path: AliasPath = [
alias if isinstance(alias, AliasPath) else cast(AliasPath, alias.choices[0])
for alias in (field_info.alias, field_info.validation_alias)
if isinstance(alias, (AliasPath, AliasChoices))
][0]

alias_nested_paths: list[str] = alias_path.path[1:-1] # type: ignore
if '.' in dest:
alias_nested_paths = dest.split('.') + alias_nested_paths
dest = alias_nested_paths.pop(0)

if not alias_nested_paths:
alias_path_only_defaults.setdefault(dest, [])
alias_default = alias_path_only_defaults[dest]
else:
alias_path_only_defaults.setdefault(dest, {})
current_path = alias_path_only_defaults[dest]

for nested_path in alias_nested_paths[:-1]:
current_path.setdefault(nested_path, {})
current_path = current_path[nested_path]
current_path.setdefault(alias_nested_paths[-1], [])
alias_default = current_path[alias_nested_paths[-1]]

alias_path_index = cast(int, alias_path.path[-1])
alias_default.extend([''] * max(alias_path_index + 1 - len(alias_default), 0))
alias_default[alias_path_index] = default

def _serialized_args(self) -> list[str]:
if not self._is_serialize_args:
raise SettingsError('Root parser is not _CliInternalArgSerializer')

cli_args = []
for arg, values in self._serialize_positional_args.items():
for value in values if isinstance(values, list) else [values]:
value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value)
cli_args.append(value)

for arg, value in self.env_vars.items():
if arg not in self._serialize_positional_args:
value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value)
cli_args.append(f'{self.cli_flag_prefix_char * min(len(arg), 2)}{arg}')
cli_args.append(value)

return cli_args
Loading