diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 09c74560..b6f0f646 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -41,14 +41,15 @@ from dotenv import dotenv_values from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, Secret, TypeAdapter from pydantic._internal._repr import Representation -from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base -from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass +from pydantic._internal._utils import deep_update, is_model_class from pydantic.dataclasses import is_pydantic_dataclass from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined -from typing_extensions import _AnnotatedAlias, get_args, get_origin +from typing_extensions import get_args, get_origin +from typing_inspection import typing_objects +from typing_inspection.introspection import is_union_origin -from pydantic_settings.utils import path_type_label +from pydantic_settings.utils import _lenient_issubclass, _WithArgsTypes, path_type_label if TYPE_CHECKING: if sys.version_info >= (3, 11): @@ -484,7 +485,7 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s field_info.append((v_alias, self._apply_case_sensitive(v_alias), False)) if not v_alias or self.config.get('populate_by_name', False): - if origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata): + if is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata): field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True)) else: field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False)) @@ -530,12 +531,13 @@ class Settings(BaseSettings): annotation = field.annotation # If field is Optional, we need to find the actual type - args = get_args(annotation) - if origin_is_union(get_origin(field.annotation)) and len(args) == 2 and type(None) in args: - for arg in args: - if arg is not None: - annotation = arg - break + if is_union_origin(get_origin(field.annotation)): + args = get_args(annotation) + if len(args) == 2 and type(None) in args: + for arg in args: + if arg is not None: + annotation = arg + break # This is here to make mypy happy # Item "None" of "Optional[Type[Any]]" has no attribute "model_fields" @@ -553,7 +555,7 @@ class Settings(BaseSettings): values[name] = value continue - if lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict): + if _lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict): values[sub_model_field_name] = self._replace_field_names_case_insensitively(sub_model_field, value) else: values[sub_model_field_name] = value @@ -623,7 +625,7 @@ def __call__(self) -> dict[str, Any]: field_value = None if ( not self.case_sensitive - # and lenient_issubclass(field.annotation, BaseModel) + # and _lenient_issubclass(field.annotation, BaseModel) and isinstance(field_value, dict) ): data[field_key] = self._replace_field_names_case_insensitively(field, field_value) @@ -842,7 +844,7 @@ def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]: """ if self.field_is_complex(field): allow_parse_failure = False - elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata): + elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata): allow_parse_failure = True else: return False, False @@ -888,12 +890,11 @@ class Cfg(BaseSettings): return None annotation = field.annotation if isinstance(field, FieldInfo) else field - if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes): - for type_ in get_args(annotation): - type_has_key = self.next_field(type_, key, case_sensitive) - if type_has_key: - return type_has_key - elif is_model_class(annotation) or is_pydantic_dataclass(annotation): + for type_ in get_args(annotation): + type_has_key = self.next_field(type_, key, case_sensitive) + if type_has_key: + return type_has_key + if is_model_class(annotation) or is_pydantic_dataclass(annotation): fields = _get_model_fields(annotation) # `case_sensitive is None` is here to be compatible with the old behavior. # Has to be removed in V3. @@ -923,7 +924,8 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[ if not self.env_nested_delimiter: return {} - is_dict = lenient_issubclass(get_origin(field.annotation), dict) + ann = field.annotation + is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict) prefixes = [ f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name) @@ -1065,7 +1067,7 @@ def __call__(self) -> dict[str, Any]: ( _annotation_is_complex(field.annotation, field.metadata) or ( - origin_is_union(get_origin(field.annotation)) + is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata) ) ) @@ -1382,7 +1384,7 @@ def _get_merge_parsed_list_types( merge_type = self._cli_dict_args.get(field_name, list) if ( merge_type is list - or not origin_is_union(get_origin(merge_type)) + or not is_union_origin(get_origin(merge_type)) or not any( type_ for type_ in get_args(merge_type) @@ -1526,7 +1528,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo] alias_names, *_ = _get_alias_names(field_name, field_info) if len(alias_names) > 1: raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases') - field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)] + field_types = (type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)) for field_type in field_types: if not (is_model_class(field_type) or is_pydantic_dataclass(field_type)): raise SettingsError( @@ -1984,19 +1986,20 @@ def _metavar_format_recurse(self, obj: Any) -> str: return '...' elif isinstance(obj, Representation): return repr(obj) - elif isinstance(obj, typing_extensions.TypeAliasType): + elif typing_objects.is_typealiastype(obj): return str(obj) - if not isinstance(obj, (typing_base, WithArgsTypes, type)): + origin = get_origin(obj) + if origin is None and not isinstance(obj, (type, typing.ForwardRef, typing_extensions.ForwardRef)): obj = obj.__class__ - if origin_is_union(get_origin(obj)): + if is_union_origin(origin): return self._metavar_format_choices(list(map(self._metavar_format_recurse, self._get_modified_args(obj)))) - elif get_origin(obj) in (typing_extensions.Literal, typing.Literal): + elif typing_objects.is_literal(origin): return self._metavar_format_choices(list(map(str, self._get_modified_args(obj)))) - elif lenient_issubclass(obj, Enum): + elif _lenient_issubclass(obj, Enum): return self._metavar_format_choices([val.name for val in obj]) - elif isinstance(obj, WithArgsTypes): + elif isinstance(obj, _WithArgsTypes): return self._metavar_format_choices( list(map(self._metavar_format_recurse, self._get_modified_args(obj))), obj_qualname=obj.__qualname__ if hasattr(obj, '__qualname__') else str(obj), @@ -2292,25 +2295,22 @@ def read_env_file( def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool: # If the model is a root model, the root annotation should be used to # evaluate the complexity. - try: - if annotation is not None and issubclass(annotation, RootModel): - # In some rare cases (see test_root_model_as_field), - # the root attribute is not available. For these cases, python 3.8 and 3.9 - # return 'RootModelRootType'. - root_annotation = annotation.__annotations__.get('root', None) - if root_annotation is not None and root_annotation != 'RootModelRootType': - annotation = root_annotation - except TypeError: - pass + if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel: + annotation = cast('type[RootModel[Any]]', annotation) + root_annotation = annotation.model_fields['root'].annotation + if root_annotation is not None: + annotation = root_annotation if any(isinstance(md, Json) for md in metadata): # type: ignore[misc] return False + + origin = get_origin(annotation) + # Check if annotation is of the form Annotated[type, metadata]. - if isinstance(annotation, _AnnotatedAlias): + if typing_objects.is_annotated(origin): # Return result of recursive call on inner type. inner, *meta = get_args(annotation) return _annotation_is_complex(inner, meta) - origin = get_origin(annotation) if origin is Secret: return False @@ -2324,12 +2324,12 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool: - if lenient_issubclass(annotation, (str, bytes)): + if _lenient_issubclass(annotation, (str, bytes)): return False - return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass( - annotation - ) + return _lenient_issubclass( + annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque) + ) or is_dataclass(annotation) def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool: @@ -2353,14 +2353,15 @@ def _annotation_contains_types( def _strip_annotated(annotation: Any) -> Any: - while get_origin(annotation) == Annotated: - annotation = get_args(annotation)[0] - return annotation + if typing_objects.is_annotated(get_origin(annotation)): + return annotation.__origin__ + else: + return annotation def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> Optional[str]: for type_ in (annotation, get_origin(annotation), *get_args(annotation)): - if lenient_issubclass(type_, Enum): + if _lenient_issubclass(type_, Enum): if value in tuple(val.value for val in type_): return type_(value).name return None @@ -2368,7 +2369,7 @@ def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> Op def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any: for type_ in (annotation, get_origin(annotation), *get_args(annotation)): - if lenient_issubclass(type_, Enum): + if _lenient_issubclass(type_, Enum): if name in tuple(val.name for val in type_): return type_[name] return None diff --git a/pydantic_settings/utils.py b/pydantic_settings/utils.py index 35dceea2..d4326b59 100644 --- a/pydantic_settings/utils.py +++ b/pydantic_settings/utils.py @@ -1,4 +1,9 @@ +import sys +import types from pathlib import Path +from typing import Any, _GenericAlias # type: ignore [attr-defined] + +from typing_extensions import get_origin _PATH_TYPE_LABELS = { Path.is_dir: 'directory', @@ -22,3 +27,22 @@ def path_type_label(p: Path) -> str: return name return 'unknown' + + +# TODO remove and replace usage by `isinstance(cls, type) and issubclass(cls, class_or_tuple)` +# once we drop support for Python 3.10. +def _lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover + try: + return isinstance(cls, type) and issubclass(cls, class_or_tuple) + except TypeError: + if get_origin(cls) is not None: + # Up until Python 3.10, isinstance(, type) is True + # (e.g. list[int]) + return False + raise + + +if sys.version_info < (3, 10): + _WithArgsTypes = tuple() +else: + _WithArgsTypes = (_GenericAlias, types.GenericAlias, types.UnionType) diff --git a/pyproject.toml b/pyproject.toml index ee1d7131..d3b1c8f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ requires-python = '>=3.9' dependencies = [ 'pydantic>=2.7.0', 'python-dotenv>=0.21.0', + 'typing-inspection>=0.4.0', ] dynamic = ['version'] diff --git a/requirements/pyproject.txt b/requirements/pyproject.txt index ace5a5c2..e94dce4b 100644 --- a/requirements/pyproject.txt +++ b/requirements/pyproject.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.13 # by the following command: # # pip-compile --extra=azure-key-vault --extra=toml --extra=yaml --no-emit-index-url --output-file=requirements/pyproject.txt pyproject.toml @@ -63,11 +63,13 @@ tomli==2.0.1 # via pydantic-settings (pyproject.toml) typing-extensions==4.12.2 # via - # annotated-types # azure-core # azure-identity # azure-keyvault-secrets # pydantic # pydantic-core + # typing-inspection +typing-inspection==0.4.0 + # via pydantic-settings (pyproject.toml) urllib3==2.2.2 # via requests