Skip to content

Commit 9380bc6

Browse files
authored
CLI Serialize Support (#643)
1 parent 6bae3ab commit 9380bc6

File tree

4 files changed

+224
-24
lines changed

4 files changed

+224
-24
lines changed

docs/index.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,29 @@ CliApp.run(Git, cli_args=['clone', 'repo', 'dir']).model_dump() == {
11731173

11741174
When executing a subcommand with an asynchronous cli_cmd, Pydantic settings automatically detects whether the current thread already has an active event loop. If so, the async command is run in a fresh thread to avoid conflicts. Otherwise, it uses asyncio.run() in the current thread. This handling ensures your asynchronous subcommands "just work" without additional manual setup.
11751175

1176+
### Serializing CLI Arguments
1177+
1178+
An instantiated Pydantic model can be serialized into its CLI arguments using the `CliApp.serialize` method.
1179+
1180+
```py
1181+
from pydantic import BaseModel
1182+
1183+
from pydantic_settings import CliApp
1184+
1185+
1186+
class Nested(BaseModel):
1187+
that: int
1188+
1189+
1190+
class Settings(BaseModel):
1191+
this: str
1192+
nested: Nested
1193+
1194+
1195+
print(CliApp.serialize(Settings(this='hello', nested=Nested(that=123))))
1196+
#> ['--this', 'hello', '--nested.that', '123']
1197+
```
1198+
11761199
### Mutually Exclusive Groups
11771200

11781201
CLI mutually exclusive groups can be created by inheriting from the `CliMutuallyExclusiveGroup` class.

pydantic_settings/main.py

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
from types import SimpleNamespace
99
from typing import Any, ClassVar, TypeVar
1010

11-
from pydantic import ConfigDict
11+
from pydantic import ConfigDict, create_model
1212
from pydantic._internal._config import config_keys
1313
from pydantic._internal._signature import _field_name_for_signature
1414
from pydantic._internal._utils import deep_update, is_model_class
1515
from pydantic.dataclasses import is_pydantic_dataclass
16+
from pydantic.fields import FieldInfo
1617
from pydantic.main import BaseModel
1718

1819
from .exceptions import SettingsError
@@ -30,6 +31,7 @@
3031
SecretsSettingsSource,
3132
get_subcommand,
3233
)
34+
from .sources.providers.cli import _CliInternalArgSerializer
3335

3436
T = TypeVar('T')
3537

@@ -477,6 +479,25 @@ class CliApp:
477479
CLI applications.
478480
"""
479481

482+
@staticmethod
def _get_base_settings_cls(model_cls: type[Any]) -> type[BaseSettings]:
    """
    Return a `BaseSettings` class that can stand in for `model_cls`.

    Args:
        model_cls: The model class to adapt.

    Returns:
        `model_cls` itself when it already derives from `BaseSettings`; otherwise a
        new class inheriting from both `BaseSettings` and `model_cls`, carrying the
        CLI-oriented settings configuration declared below.
    """
    if not issubclass(model_cls, BaseSettings):
        # Configuration applied only to the generated wrapper class.
        cli_app_config = SettingsConfigDict(
            nested_model_default_partial_update=True,
            case_sensitive=True,
            cli_hide_none_type=True,
            cli_avoid_json=True,
            cli_enforce_required=True,
            cli_implicit_flags=True,
            cli_kebab_case=True,
        )

        class CliAppBaseSettings(BaseSettings, model_cls):  # type: ignore
            # Keep the wrapped model's docstring on the wrapper.
            __doc__ = model_cls.__doc__
            model_config = cli_app_config

        return CliAppBaseSettings

    return model_cls
500+
480501
@staticmethod
481502
def _run_cli_cmd(model: Any, cli_cmd_method_name: str, is_required: bool) -> Any:
482503
command = getattr(type(model), cli_cmd_method_name, None)
@@ -575,22 +596,10 @@ def run(
575596
model_init_data['_cli_exit_on_error'] = cli_exit_on_error
576597
model_init_data['_cli_settings_source'] = cli_settings
577598
if not issubclass(model_cls, BaseSettings):
578-
579-
class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore
580-
__doc__ = model_cls.__doc__
581-
model_config = SettingsConfigDict(
582-
nested_model_default_partial_update=True,
583-
case_sensitive=True,
584-
cli_hide_none_type=True,
585-
cli_avoid_json=True,
586-
cli_enforce_required=True,
587-
cli_implicit_flags=True,
588-
cli_kebab_case=True,
589-
)
590-
591-
model = CliAppBaseSettings(**model_init_data)
599+
base_settings_cls = CliApp._get_base_settings_cls(model_cls)
600+
model = base_settings_cls(**model_init_data)
592601
model_init_data = {}
593-
for field_name, field_info in type(model).model_fields.items():
602+
for field_name, field_info in base_settings_cls.model_fields.items():
594603
model_init_data[_field_name_for_signature(field_name, field_info)] = getattr(model, field_name)
595604

596605
return CliApp._run_cli_cmd(model_cls(**model_init_data), cli_cmd_method_name, is_required=False)
@@ -619,3 +628,28 @@ def run_subcommand(
619628

620629
subcommand = get_subcommand(model, is_required=True, cli_exit_on_error=cli_exit_on_error)
621630
return CliApp._run_cli_cmd(subcommand, cli_cmd_method_name, is_required=True)
631+
632+
@staticmethod
def serialize(model: PydanticModel) -> list[str]:
    """
    Serializes the CLI arguments for a Pydantic data model.

    Args:
        model: The data model to serialize.

    Returns:
        The serialized CLI arguments for the data model.
    """
    settings_cls = CliApp._get_base_settings_cls(type(model))

    # Re-declare every field with the instance's current value as its default.
    field_overrides: dict[str, Any] = {
        name: (
            info.annotation,
            FieldInfo.merge_field_infos(info, default=getattr(model, name)),
        )
        for name, info in settings_cls.model_fields.items()
    }
    serialize_cls = create_model('CliSerialize', __base__=settings_cls, **field_overrides)

    # No CLI arguments are parsed; the serializer root parser switches the source
    # into serialize mode.
    serialize_source = CliSettingsSource[Any](
        serialize_cls, cli_parse_args=[], root_parser=_CliInternalArgSerializer()
    )
    return serialize_source._serialized_args()

pydantic_settings/sources/providers/cli.py

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636

3737
import typing_extensions
38-
from pydantic import BaseModel, Field
38+
from pydantic import AliasChoices, AliasPath, BaseModel, Field
3939
from pydantic._internal._repr import Representation
4040
from pydantic._internal._utils import is_model_class
4141
from pydantic.dataclasses import is_pydantic_dataclass
@@ -74,6 +74,10 @@ def error(self, message: str) -> NoReturn:
7474
super().error(message)
7575

7676

77+
class _CliInternalArgSerializer(_CliInternalArgParser):
    """Marker subclass of `_CliInternalArgParser`; it adds no behavior of its own.

    The CLI settings source checks whether its root parser is an instance of this
    class to decide whether it is serializing arguments.
    """
79+
80+
7781
class CliMutuallyExclusiveGroup(BaseModel):
7882
pass
7983

@@ -664,6 +668,8 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
664668
self._formatter_class = formatter_class
665669
self._cli_dict_args: dict[str, type[Any] | None] = {}
666670
self._cli_subcommands: defaultdict[str, dict[str, str]] = defaultdict(dict)
671+
self._is_serialize_args = isinstance(root_parser, _CliInternalArgSerializer)
672+
self._serialize_positional_args: dict[str, Any] = {}
667673
self._add_parser_args(
668674
parser=self.root_parser,
669675
model=self.settings_cls,
@@ -689,6 +695,7 @@ def _add_parser_args(
689695
) -> ArgumentParser:
690696
subparsers: Any = None
691697
alias_path_args: dict[str, str] = {}
698+
alias_path_only_defaults: dict[str, Any] = {}
692699
# Ignore model default if the default is a model and not a subclass of the current model.
693700
model_default = (
694701
None
@@ -756,9 +763,11 @@ def _add_parser_args(
756763
is_append_action = _annotation_contains_types(
757764
field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True
758765
)
759-
is_parser_submodel = sub_models and not is_append_action
766+
is_parser_submodel = bool(sub_models) and not is_append_action
760767
kwargs: dict[str, Any] = {}
761-
kwargs['default'] = CLI_SUPPRESS
768+
kwargs['default'] = self._get_cli_default_value(
769+
field_name, field_info, model_default, is_parser_submodel
770+
)
762771
kwargs['help'] = self._help_format(field_name, field_info, model_default, is_model_suppressed)
763772
kwargs['metavar'] = self._metavar_format(field_info.annotation)
764773
kwargs['required'] = (
@@ -817,8 +826,14 @@ def _add_parser_args(
817826
self._add_argument(
818827
parser, *(f'{flag_prefix[: len(name)]}{name}' for name in arg_names), **kwargs
819828
)
829+
elif kwargs['default'] != CLI_SUPPRESS:
830+
self._update_alias_path_only_defaults(
831+
kwargs['dest'], kwargs['default'], field_info, alias_path_only_defaults
832+
)
820833

821-
self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
834+
self._add_parser_alias_paths(
835+
parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group, alias_path_only_defaults
836+
)
822837
return parser
823838

824839
def _check_kebab_name(self, name: str) -> str:
@@ -845,8 +860,6 @@ def _convert_positional_arg(
845860
) -> tuple[list[str], str]:
846861
flag_prefix = ''
847862
arg_names = [kwargs['dest']]
848-
kwargs['default'] = PydanticUndefined
849-
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())
850863

851864
# Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in
852865
# conjunction with model_default instead of the derived kwargs['required'].
@@ -857,6 +870,13 @@ def _convert_positional_arg(
857870
elif not is_required:
858871
kwargs['nargs'] = '?'
859872

873+
if self._is_serialize_args:
874+
self._serialize_positional_args[kwargs['dest']] = kwargs['default']
875+
kwargs['nargs'] = '*'
876+
877+
kwargs['default'] = PydanticUndefined
878+
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())
879+
860880
del kwargs['dest']
861881
del kwargs['required']
862882
return arg_names, flag_prefix
@@ -944,7 +964,7 @@ def _add_parser_submodels(
944964
is_model_suppressed = self._is_field_suppressed(field_info) or is_model_suppressed
945965
if is_model_suppressed:
946966
model_group_kwargs['description'] = CLI_SUPPRESS
947-
if not self.cli_avoid_json:
967+
if not self.cli_avoid_json and not self._is_serialize_args:
948968
added_args.append(arg_names[0])
949969
kwargs['nargs'] = '?'
950970
kwargs['const'] = '{}'
@@ -974,6 +994,7 @@ def _add_parser_alias_paths(
974994
arg_prefix: str,
975995
subcommand_prefix: str,
976996
group: Any,
997+
alias_path_only_defaults: dict[str, Any],
977998
) -> None:
978999
if alias_path_args:
9791000
context = parser
@@ -989,9 +1010,9 @@ def _add_parser_alias_paths(
9891010
else f'{arg_prefix.replace(subcommand_prefix, "", 1)}{name}'
9901011
)
9911012
kwargs: dict[str, Any] = {}
992-
kwargs['default'] = CLI_SUPPRESS
9931013
kwargs['help'] = 'pydantic alias path'
9941014
kwargs['dest'] = f'{arg_prefix}{name}'
1015+
kwargs['default'] = alias_path_only_defaults.get(kwargs['dest'], CLI_SUPPRESS)
9951016
if metavar == 'dict' or is_nested_alias_path:
9961017
kwargs['metavar'] = 'dict'
9971018
else:
@@ -1084,3 +1105,60 @@ def _help_format(
10841105
def _is_field_suppressed(self, field_info: FieldInfo) -> bool:
    """Report whether `field_info` carries `CLI_SUPPRESS`, either as its whole
    description or as an entry of its metadata."""
    description = field_info.description or ''
    if description == CLI_SUPPRESS:
        return True
    return CLI_SUPPRESS in field_info.metadata
1108+
1109+
def _get_cli_default_value(
    self, field_name: str, field_info: FieldInfo, model_default: Any, is_parser_submodel: bool
) -> Any:
    """Pick the argparse default for a field.

    Args:
        field_name: Name of the field on the model.
        field_info: The field's `FieldInfo`; its `default` is the fallback value.
        model_default: Object whose `field_name` attribute is used as the default
            when present.
        is_parser_submodel: Whether the field is handled as a parser sub-model.

    Returns:
        `CLI_SUPPRESS` for parser sub-models and whenever the source is not
        serializing; otherwise the `field_name` attribute of `model_default`,
        falling back to `field_info.default`.
    """
    # Use the flag cached on the source rather than re-inspecting the root parser,
    # matching how the rest of the source detects serialize mode.
    if is_parser_submodel or not self._is_serialize_args:
        return CLI_SUPPRESS

    return getattr(model_default, field_name, field_info.default)
1116+
1117+
def _update_alias_path_only_defaults(
    self, dest: str, default: Any, field_info: FieldInfo, alias_path_only_defaults: dict[str, Any]
) -> None:
    """Record `default` inside `alias_path_only_defaults` at the slot named by the
    field's alias path.

    The alias path is the first of `field_info.alias` / `field_info.validation_alias`
    that is an `AliasPath`, or the first choice of an `AliasChoices`. The path
    elements between its first and last become nested dict keys (preceded by the
    dotted components of `dest`, when it has any), and the last element is used as
    an index into a list stored at the innermost key.

    Args:
        dest: Destination name of the argument; may be dotted.
        default: Value to store.
        field_info: Field whose alias supplies the path.
        alias_path_only_defaults: Mapping updated in place.
    """
    alias_path_candidates = [
        alias if isinstance(alias, AliasPath) else cast(AliasPath, alias.choices[0])
        for alias in (field_info.alias, field_info.validation_alias)
        if isinstance(alias, (AliasPath, AliasChoices))
    ]
    alias_path: AliasPath = alias_path_candidates[0]

    nested_keys: list[str] = alias_path.path[1:-1]  # type: ignore
    if '.' in dest:
        # A dotted destination contributes its trailing components as outer keys.
        nested_keys = dest.split('.') + nested_keys
        dest = nested_keys.pop(0)

    if nested_keys:
        # Walk (creating as needed) the dict levels down to the list-holding key.
        container = alias_path_only_defaults.setdefault(dest, {})
        for key in nested_keys[:-1]:
            container = container.setdefault(key, {})
        slots = container.setdefault(nested_keys[-1], [])
    else:
        slots = alias_path_only_defaults.setdefault(dest, [])

    # NOTE(review): assumes the final path element is an integer index — confirm.
    slot_index = cast(int, alias_path.path[-1])
    missing = slot_index + 1 - len(slots)
    if missing > 0:
        # Pad with empty strings so the target index exists.
        slots.extend([''] * missing)
    slots[slot_index] = default
1147+
1148+
def _serialized_args(self) -> list[str]:
    """Build the CLI argument list from the values held by this source.

    Values recorded in `self._serialize_positional_args` are emitted first, without
    a flag; a list value contributes one argument per item. Every other entry of
    `self.env_vars` is emitted as a flag followed by its value, the flag being the
    name prefixed with one `cli_flag_prefix_char` for single-character names and
    two otherwise. Dicts, lists and sets are rendered as JSON; anything else with
    `str()`.

    Returns:
        The serialized CLI arguments.

    Raises:
        SettingsError: If the root parser is not a `_CliInternalArgSerializer`.
    """
    if not self._is_serialize_args:
        raise SettingsError('Root parser is not _CliInternalArgSerializer')

    def _to_cli_text(value: Any) -> str:
        # json.dumps cannot encode a set (TypeError), so emit it as a JSON array.
        if isinstance(value, set):
            value = list(value)
        return json.dumps(value) if isinstance(value, (dict, list)) else str(value)

    cli_args: list[str] = []
    for values in self._serialize_positional_args.values():
        items = values if isinstance(values, list) else [values]
        cli_args.extend(_to_cli_text(item) for item in items)

    for arg, value in self.env_vars.items():
        # Positionals were already emitted above, without a flag.
        if arg in self._serialize_positional_args:
            continue
        cli_args.append(f'{self.cli_flag_prefix_char * min(len(arg), 2)}{arg}')
        cli_args.append(_to_cli_text(value))

    return cli_args

0 commit comments

Comments
 (0)