Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 52 additions & 51 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@
from dotenv import dotenv_values
from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, Secret, TypeAdapter
from pydantic._internal._repr import Representation
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
from pydantic._internal._utils import deep_update, is_model_class
from pydantic.dataclasses import is_pydantic_dataclass
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from typing_extensions import _AnnotatedAlias, get_args, get_origin
from typing_extensions import get_args, get_origin
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin

from pydantic_settings.utils import path_type_label
from pydantic_settings.utils import _lenient_issubclass, _WithArgsTypes, path_type_label

if TYPE_CHECKING:
if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -484,7 +485,7 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s
field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))

if not v_alias or self.config.get('populate_by_name', False):
if origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
if is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True))
else:
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))
Expand Down Expand Up @@ -530,12 +531,13 @@ class Settings(BaseSettings):
annotation = field.annotation

# If field is Optional, we need to find the actual type
args = get_args(annotation)
if origin_is_union(get_origin(field.annotation)) and len(args) == 2 and type(None) in args:
for arg in args:
if arg is not None:
annotation = arg
break
if is_union_origin(get_origin(field.annotation)):
args = get_args(annotation)
if len(args) == 2 and type(None) in args:
for arg in args:
if arg is not None:
annotation = arg
break

# This is here to make mypy happy
# Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
Expand All @@ -553,7 +555,7 @@ class Settings(BaseSettings):
values[name] = value
continue

if lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict):
if _lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict):
values[sub_model_field_name] = self._replace_field_names_case_insensitively(sub_model_field, value)
else:
values[sub_model_field_name] = value
Expand Down Expand Up @@ -623,7 +625,7 @@ def __call__(self) -> dict[str, Any]:
field_value = None
if (
not self.case_sensitive
# and lenient_issubclass(field.annotation, BaseModel)
# and _lenient_issubclass(field.annotation, BaseModel)
and isinstance(field_value, dict)
):
data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
Expand Down Expand Up @@ -842,7 +844,7 @@ def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
"""
if self.field_is_complex(field):
allow_parse_failure = False
elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
allow_parse_failure = True
else:
return False, False
Expand Down Expand Up @@ -888,12 +890,11 @@ class Cfg(BaseSettings):
return None

annotation = field.annotation if isinstance(field, FieldInfo) else field
if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes):
for type_ in get_args(annotation):
type_has_key = self.next_field(type_, key, case_sensitive)
if type_has_key:
return type_has_key
elif is_model_class(annotation) or is_pydantic_dataclass(annotation):
for type_ in get_args(annotation):
type_has_key = self.next_field(type_, key, case_sensitive)
if type_has_key:
return type_has_key
if is_model_class(annotation) or is_pydantic_dataclass(annotation):
fields = _get_model_fields(annotation)
# `case_sensitive is None` is here to be compatible with the old behavior.
# Has to be removed in V3.
Expand Down Expand Up @@ -923,7 +924,8 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
if not self.env_nested_delimiter:
return {}

is_dict = lenient_issubclass(get_origin(field.annotation), dict)
ann = field.annotation
is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)

prefixes = [
f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
Expand Down Expand Up @@ -1065,7 +1067,7 @@ def __call__(self) -> dict[str, Any]:
(
_annotation_is_complex(field.annotation, field.metadata)
or (
origin_is_union(get_origin(field.annotation))
is_union_origin(get_origin(field.annotation))
and _union_is_complex(field.annotation, field.metadata)
)
)
Expand Down Expand Up @@ -1382,7 +1384,7 @@ def _get_merge_parsed_list_types(
merge_type = self._cli_dict_args.get(field_name, list)
if (
merge_type is list
or not origin_is_union(get_origin(merge_type))
or not is_union_origin(get_origin(merge_type))
or not any(
type_
for type_ in get_args(merge_type)
Expand Down Expand Up @@ -1526,7 +1528,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases')
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
field_types = (type_ for type_ in get_args(field_info.annotation) if type_ is not type(None))
for field_type in field_types:
if not (is_model_class(field_type) or is_pydantic_dataclass(field_type)):
raise SettingsError(
Expand Down Expand Up @@ -1984,19 +1986,20 @@ def _metavar_format_recurse(self, obj: Any) -> str:
return '...'
elif isinstance(obj, Representation):
return repr(obj)
elif isinstance(obj, typing_extensions.TypeAliasType):
elif typing_objects.is_typealiastype(obj):
return str(obj)

if not isinstance(obj, (typing_base, WithArgsTypes, type)):
origin = get_origin(obj)
Copy link
Member Author

@Viicos Viicos Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is still a big mess, in the future I'd like to add a string repr function in typing-inspection

if origin is None and not isinstance(obj, (type, typing.ForwardRef, typing_extensions.ForwardRef)):
obj = obj.__class__

if origin_is_union(get_origin(obj)):
if is_union_origin(origin):
return self._metavar_format_choices(list(map(self._metavar_format_recurse, self._get_modified_args(obj))))
elif get_origin(obj) in (typing_extensions.Literal, typing.Literal):
elif typing_objects.is_literal(origin):
return self._metavar_format_choices(list(map(str, self._get_modified_args(obj))))
elif lenient_issubclass(obj, Enum):
elif _lenient_issubclass(obj, Enum):
return self._metavar_format_choices([val.name for val in obj])
elif isinstance(obj, WithArgsTypes):
elif isinstance(obj, _WithArgsTypes):
return self._metavar_format_choices(
list(map(self._metavar_format_recurse, self._get_modified_args(obj))),
obj_qualname=obj.__qualname__ if hasattr(obj, '__qualname__') else str(obj),
Expand Down Expand Up @@ -2292,25 +2295,22 @@ def read_env_file(
def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
# If the model is a root model, the root annotation should be used to
# evaluate the complexity.
try:
if annotation is not None and issubclass(annotation, RootModel):
# In some rare cases (see test_root_model_as_field),
# the root attribute is not available. For these cases, python 3.8 and 3.9
# return 'RootModelRootType'.
root_annotation = annotation.__annotations__.get('root', None)
if root_annotation is not None and root_annotation != 'RootModelRootType':
annotation = root_annotation
except TypeError:
pass
if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_lenient_issubclass takes care of the TypeError.

Also, checking for annotation is not RootModel and relying on model_fields['root'].annotation is more robust

annotation = cast('type[RootModel[Any]]', annotation)
root_annotation = annotation.model_fields['root'].annotation
if root_annotation is not None:
annotation = root_annotation

if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
return False

origin = get_origin(annotation)

# Check if annotation is of the form Annotated[type, metadata].
if isinstance(annotation, _AnnotatedAlias):
if typing_objects.is_annotated(origin):
# Return result of recursive call on inner type.
inner, *meta = get_args(annotation)
return _annotation_is_complex(inner, meta)
origin = get_origin(annotation)

if origin is Secret:
return False
Expand All @@ -2324,12 +2324,12 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) ->


def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
if _lenient_issubclass(annotation, (str, bytes)):
return False

return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass(
annotation
)
return _lenient_issubclass(
annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)
) or is_dataclass(annotation)


def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
Expand All @@ -2353,22 +2353,23 @@ def _annotation_contains_types(


def _strip_annotated(annotation: Any) -> Any:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python takes care of flattening Annotated forms already.

while get_origin(annotation) == Annotated:
annotation = get_args(annotation)[0]
return annotation
if typing_objects.is_annotated(get_origin(annotation)):
return annotation.__origin__
else:
return annotation


def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> Optional[str]:
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
if lenient_issubclass(type_, Enum):
if _lenient_issubclass(type_, Enum):
if value in tuple(val.value for val in type_):
return type_(value).name
return None


def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any:
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
if lenient_issubclass(type_, Enum):
if _lenient_issubclass(type_, Enum):
if name in tuple(val.name for val in type_):
return type_[name]
return None
Expand Down
24 changes: 24 additions & 0 deletions pydantic_settings/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import sys
import types
from pathlib import Path
from typing import Any, _GenericAlias # type: ignore [attr-defined]

from typing_extensions import get_origin

_PATH_TYPE_LABELS = {
Path.is_dir: 'directory',
Expand All @@ -22,3 +27,22 @@ def path_type_label(p: Path) -> str:
return name

return 'unknown'


# TODO remove and replace usage by `isinstance(cls, type) and issubclass(cls, class_or_tuple)`
# once we drop support for Python 3.10.
def _lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
except TypeError:
if get_origin(cls) is not None:
# Up until Python 3.10, isinstance(<generic_alias>, type) is True
# (e.g. list[int])
return False
raise


if sys.version_info < (3, 10):
_WithArgsTypes = tuple()
else:
_WithArgsTypes = (_GenericAlias, types.GenericAlias, types.UnionType)
Comment on lines +45 to +48
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is taken from Pydantic 2.9. The < (3, 10) branch was actually wrong (it was missing _GenericAlias), and pydantic-settings relied on this inconsistency.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this was reported here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm keeping the inconsistency here because the metavar format function is really messy and I can't reason properly when looking into it. It needs to be refactored and as mentioned in another comment this will probably be done in typing-inspection.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ requires-python = '>=3.9'
dependencies = [
'pydantic>=2.7.0',
'python-dotenv>=0.21.0',
'typing-inspection>=0.4.0',
]
dynamic = ['version']

Expand Down
6 changes: 4 additions & 2 deletions requirements/pyproject.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# This file is autogenerated by pip-compile with Python 3.13
# by the following command:
#
# pip-compile --extra=azure-key-vault --extra=toml --extra=yaml --no-emit-index-url --output-file=requirements/pyproject.txt pyproject.toml
Expand Down Expand Up @@ -63,11 +63,13 @@ tomli==2.0.1
# via pydantic-settings (pyproject.toml)
typing-extensions==4.12.2
# via
# annotated-types
# azure-core
# azure-identity
# azure-keyvault-secrets
# pydantic
# pydantic-core
# typing-inspection
typing-inspection==0.4.0
# via pydantic-settings (pyproject.toml)
urllib3==2.2.2
# via requests