Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,10 @@ print(User().model_dump())

### Subcommands and Positional Arguments

Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. These
annotations can only be applied to required fields (i.e. fields that do not have a default value). Furthermore,
subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses `dataclass`.
Subcommands and positional arguments are expressed using the `CliSubCommand` and `CliPositionalArg` annotations. The
subcommand annotation can only be applied to required fields (i.e. fields that do not have a default value).
Furthermore, subcommands must be a valid type derived from either a pydantic `BaseModel` or pydantic.dataclasses
`dataclass`.

Parsed subcommands can be retrieved from model instances using the `get_subcommand` utility function. If a subcommand is
not required, set the `is_required` flag to `False` to disable raising an error if no subcommand is found.
Expand Down
38 changes: 26 additions & 12 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,11 @@ def _load_env_vars(
if subcommand_dest not in selected_subcommands:
parsed_args[subcommand_dest] = self.cli_parse_none_str

parsed_args = {key: val for key, val in parsed_args.items() if not key.endswith(':subcommand')}
parsed_args = {
key: val
for key, val in parsed_args.items()
if not key.endswith(':subcommand') and val is not PydanticUndefined
}
if selected_subcommands:
last_selected_subcommand = max(selected_subcommands, key=len)
if not any(field_name for field_name in parsed_args.keys() if f'{last_selected_subcommand}.' in field_name):
Expand Down Expand Up @@ -1511,12 +1515,9 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
)
subcommand_args.append((field_name, field_info))
elif _CliPositionalArg in field_info.metadata:
if not field_info.is_required():
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
else:
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
positional_args.append((field_name, field_info))
else:
self._verify_cli_flag_annotations(model, field_name, field_info)
Expand Down Expand Up @@ -1727,11 +1728,7 @@ def _add_parser_args(
self._cli_dict_args[kwargs['dest']] = field_info.annotation

if _CliPositionalArg in field_info.metadata:
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())
arg_names = [kwargs['dest']]
del kwargs['dest']
del kwargs['required']
flag_prefix = ''
arg_names, flag_prefix = self._convert_positional_arg(kwargs, field_info, preferred_alias)

self._convert_bool_flag(kwargs, field_info, model_default)

Expand Down Expand Up @@ -1787,6 +1784,23 @@ def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, mode
BooleanOptionalAction if sys.version_info >= (3, 9) else f'store_{str(not default).lower()}'
)

def _convert_positional_arg(
self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str
) -> tuple[list[str], str]:
flag_prefix = ''
arg_names = [kwargs['dest']]
kwargs['default'] = PydanticUndefined
kwargs['metavar'] = self._check_kebab_name(preferred_alias.upper())

# Note: For positional args, we must strictly look at field_info.is_required instead of our derived
# kwargs['required'].
if not field_info.is_required():
kwargs['nargs'] = '?'

del kwargs['dest']
del kwargs['required']
return arg_names, flag_prefix

def _get_arg_names(
self,
arg_prefix: str,
Expand Down
25 changes: 16 additions & 9 deletions tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,22 @@ class Cfg(BaseSettings):
assert cfg.model_dump() == {'child': {'name': 'new name a', 'diff_a': 'new diff a'}}


def test_cli_optional_positional_arg(env):
class Main(BaseSettings):
model_config = SettingsConfigDict(
cli_parse_args=True,
cli_enforce_required=True,
)

value: CliPositionalArg[int] = 123

assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 123}

env.set('VALUE', '456')
assert CliApp.run(Main, cli_args=[]).model_dump() == {'value': 456}

assert CliApp.run(Main, cli_args=['789']).model_dump() == {'value': 789}

def test_cli_enums(capsys, monkeypatch):
class Pet(IntEnum):
dog = 0
Expand Down Expand Up @@ -1415,15 +1431,6 @@ class PositionalArgNotOutermost(BaseSettings, cli_parse_args=True):

PositionalArgNotOutermost()

with pytest.raises(
SettingsError, match='positional argument PositionalArgHasDefault.pos_arg has a default value'
):

class PositionalArgHasDefault(BaseSettings, cli_parse_args=True):
pos_arg: CliPositionalArg[str] = 'bad'

PositionalArgHasDefault()

with pytest.raises(
SettingsError, match=re.escape("cli_parse_args must be List[str] or Tuple[str, ...], recieved <class 'str'>")
):
Expand Down
Loading