Skip to content

Commit 9c6c9b5

Browse files
authored
Cli root model support (#677)
1 parent a164b73 commit 9c6c9b5

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

pydantic_settings/sources/providers/cli.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ def _add_parser_args(
956956

957957
self._convert_bool_flag(arg.kwargs, field_info, model_default)
958958

959-
if arg.is_parser_submodel:
959+
if arg.is_parser_submodel and not getattr(field_info.annotation, '__pydantic_root_model__', False):
960960
self._add_parser_submodels(
961961
parser,
962962
model,
@@ -1107,6 +1107,7 @@ def _add_parser_submodels(
11071107
model_group_kwargs['description'] = CLI_SUPPRESS
11081108
if not self.cli_avoid_json:
11091109
added_args.append(arg_names[0])
1110+
kwargs['required'] = False
11101111
kwargs['nargs'] = '?'
11111112
kwargs['const'] = '{}'
11121113
kwargs['help'] = (
@@ -1205,8 +1206,12 @@ def _metavar_format_recurse(self, obj: Any) -> str:
12051206
)
12061207
elif obj is type(None):
12071208
return self.cli_parse_none_str
1208-
elif is_model_class(obj):
1209-
return 'JSON'
1209+
elif is_model_class(obj) or is_pydantic_dataclass(obj):
1210+
return (
1211+
self._metavar_format_recurse(_get_model_fields(obj)['root'].annotation)
1212+
if getattr(obj, '__pydantic_root_model__', False)
1213+
else 'JSON'
1214+
)
12101215
elif isinstance(obj, type):
12111216
return obj.__qualname__
12121217
else:

tests/test_source_cli.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DirectoryPath,
2020
Discriminator,
2121
Field,
22+
RootModel,
2223
Tag,
2324
ValidationError,
2425
field_validator,
@@ -1778,20 +1779,31 @@ class Settings(BaseSettings):
17781779

17791780

17801781
def test_cli_enforce_required(env):
1782+
class MyRootModel(RootModel[str]):
1783+
root: str
1784+
17811785
class Settings(BaseSettings, cli_exit_on_error=False):
17821786
my_required_field: str
1787+
my_root_model_required_field: MyRootModel
17831788

17841789
env.set('MY_REQUIRED_FIELD', 'hello from environment')
1790+
env.set('MY_ROOT_MODEL_REQUIRED_FIELD', 'hi from environment')
17851791

17861792
assert Settings(_cli_parse_args=[], _cli_enforce_required=False).model_dump() == {
1787-
'my_required_field': 'hello from environment'
1793+
'my_required_field': 'hello from environment',
1794+
'my_root_model_required_field': 'hi from environment',
17881795
}
17891796

17901797
with pytest.raises(
17911798
SettingsError, match='error parsing CLI: the following arguments are required: --my_required_field'
17921799
):
17931800
Settings(_cli_parse_args=[], _cli_enforce_required=True).model_dump()
17941801

1802+
with pytest.raises(
1803+
SettingsError, match='error parsing CLI: the following arguments are required: --my_root_model_required_field'
1804+
):
1805+
Settings(_cli_parse_args=['--my_required_field', 'hello from cli'], _cli_enforce_required=True).model_dump()
1806+
17951807

17961808
def test_cli_exit_on_error(capsys, monkeypatch):
17971809
class Settings(BaseSettings, cli_parse_args=True): ...

0 commit comments

Comments
 (0)