Skip to content

Commit 9583896

Browse files
kschwabhramezani
andauthored
Fix alias resolution for default settings source. (#468)
Co-authored-by: Hasan Ramezani <[email protected]>
1 parent b4efcd3 commit 9583896

File tree

2 files changed

+66
-41
lines changed

2 files changed

+66
-41
lines changed

pydantic_settings/sources.py

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
from dotenv import dotenv_values
4040
from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, TypeAdapter
4141
from pydantic._internal._repr import Representation
42-
from pydantic._internal._signature import _field_name_for_signature
4342
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base
4443
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
4544
from pydantic.dataclasses import is_pydantic_dataclass
@@ -336,10 +335,12 @@ def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partia
336335
)
337336
if self.nested_model_default_partial_update:
338337
for field_name, field_info in settings_cls.model_fields.items():
338+
alias_names, *_ = _get_alias_names(field_name, field_info)
339+
preferred_alias = alias_names[0]
339340
if is_dataclass(type(field_info.default)):
340-
self.defaults[_field_name_for_signature(field_name, field_info)] = asdict(field_info.default)
341+
self.defaults[preferred_alias] = asdict(field_info.default)
341342
elif is_model_class(type(field_info.default)):
342-
self.defaults[_field_name_for_signature(field_name, field_info)] = field_info.default.model_dump()
343+
self.defaults[preferred_alias] = field_info.default.model_dump()
343344

344345
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
345346
# Nothing to do here. Only implement the return statement to make mypy happy
@@ -1422,41 +1423,6 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F
14221423
sub_models.append(type_) # type: ignore
14231424
return sub_models
14241425

1425-
def _get_alias_names(
1426-
self, field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str]
1427-
) -> tuple[tuple[str, ...], bool]:
1428-
alias_names: list[str] = []
1429-
is_alias_path_only: bool = True
1430-
if not any((field_info.alias, field_info.validation_alias)):
1431-
alias_names += [field_name]
1432-
is_alias_path_only = False
1433-
else:
1434-
new_alias_paths: list[AliasPath] = []
1435-
for alias in (field_info.alias, field_info.validation_alias):
1436-
if alias is None:
1437-
continue
1438-
elif isinstance(alias, str):
1439-
alias_names.append(alias)
1440-
is_alias_path_only = False
1441-
elif isinstance(alias, AliasChoices):
1442-
for name in alias.choices:
1443-
if isinstance(name, str):
1444-
alias_names.append(name)
1445-
is_alias_path_only = False
1446-
else:
1447-
new_alias_paths.append(name)
1448-
else:
1449-
new_alias_paths.append(alias)
1450-
for alias_path in new_alias_paths:
1451-
name = cast(str, alias_path.path[0])
1452-
name = name.lower() if not self.case_sensitive else name
1453-
alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list'
1454-
if not alias_names and is_alias_path_only:
1455-
alias_names.append(name)
1456-
if not self.case_sensitive:
1457-
alias_names = [alias_name.lower() for alias_name in alias_names]
1458-
return tuple(dict.fromkeys(alias_names)), is_alias_path_only
1459-
14601426
def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:
14611427
if _CliImplicitFlag in field_info.metadata:
14621428
cli_flag_name = 'CliImplicitFlag'
@@ -1481,7 +1447,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
14811447
if not field_info.is_required():
14821448
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value')
14831449
else:
1484-
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
1450+
alias_names, *_ = _get_alias_names(field_name, field_info)
14851451
if len(alias_names) > 1:
14861452
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases')
14871453
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
@@ -1495,7 +1461,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
14951461
if not field_info.is_required():
14961462
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
14971463
else:
1498-
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
1464+
alias_names, *_ = _get_alias_names(field_name, field_info)
14991465
if len(alias_names) > 1:
15001466
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
15011467
positional_args.append((field_name, field_info))
@@ -1597,7 +1563,9 @@ def _add_parser_args(
15971563
alias_path_args: dict[str, str] = {}
15981564
for field_name, field_info in self._sort_arg_fields(model):
15991565
sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info)
1600-
alias_names, is_alias_path_only = self._get_alias_names(field_name, field_info, alias_path_args)
1566+
alias_names, is_alias_path_only = _get_alias_names(
1567+
field_name, field_info, alias_path_args=alias_path_args, case_sensitive=self.case_sensitive
1568+
)
16011569
preferred_alias = alias_names[0]
16021570
if _CliSubCommand in field_info.metadata:
16031571
for model in sub_models:
@@ -2241,5 +2209,41 @@ def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]:
22412209
raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')
22422210

22432211

2212+
def _get_alias_names(
2213+
field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str] = {}, case_sensitive: bool = True
2214+
) -> tuple[tuple[str, ...], bool]:
2215+
alias_names: list[str] = []
2216+
is_alias_path_only: bool = True
2217+
if not any((field_info.alias, field_info.validation_alias)):
2218+
alias_names += [field_name]
2219+
is_alias_path_only = False
2220+
else:
2221+
new_alias_paths: list[AliasPath] = []
2222+
for alias in (field_info.alias, field_info.validation_alias):
2223+
if alias is None:
2224+
continue
2225+
elif isinstance(alias, str):
2226+
alias_names.append(alias)
2227+
is_alias_path_only = False
2228+
elif isinstance(alias, AliasChoices):
2229+
for name in alias.choices:
2230+
if isinstance(name, str):
2231+
alias_names.append(name)
2232+
is_alias_path_only = False
2233+
else:
2234+
new_alias_paths.append(name)
2235+
else:
2236+
new_alias_paths.append(alias)
2237+
for alias_path in new_alias_paths:
2238+
name = cast(str, alias_path.path[0])
2239+
name = name.lower() if not case_sensitive else name
2240+
alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list'
2241+
if not alias_names and is_alias_path_only:
2242+
alias_names.append(name)
2243+
if not case_sensitive:
2244+
alias_names = [alias_name.lower() for alias_name in alias_names]
2245+
return tuple(dict.fromkeys(alias_names)), is_alias_path_only
2246+
2247+
22442248
def _is_function(obj: Any) -> bool:
22452249
return isinstance(obj, (FunctionType, BuiltinFunctionType))

tests/test_settings.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from annotated_types import MinLen
1414
from pydantic import (
1515
AliasChoices,
16+
AliasGenerator,
1617
AliasPath,
1718
BaseModel,
1819
Discriminator,
@@ -621,6 +622,26 @@ def settings_customise_sources(
621622
assert s.model_dump() == s_final
622623

623624

625+
def test_alias_nested_model_default_partial_update():
626+
class SubModel(BaseModel):
627+
v1: str = 'default'
628+
v2: bytes = b'hello'
629+
v3: int
630+
631+
class Settings(BaseSettings):
632+
model_config = SettingsConfigDict(
633+
nested_model_default_partial_update=True, alias_generator=AliasGenerator(lambda s: s.replace('_', '-'))
634+
)
635+
636+
v0: str = 'ok'
637+
sub_model: SubModel = SubModel(v1='top default', v3=33)
638+
639+
assert Settings(**{'sub-model': {'v1': 'cli'}}).model_dump() == {
640+
'v0': 'ok',
641+
'sub_model': {'v1': 'cli', 'v2': b'hello', 'v3': 33},
642+
}
643+
644+
624645
def test_env_str(env):
625646
class Settings(BaseSettings):
626647
apple: str = Field(None, validation_alias='BOOM')

0 commit comments

Comments
 (0)