99import warnings
1010from abc import ABC , abstractmethod
1111
12+ import typing_extensions
13+
1214if sys .version_info >= (3 , 9 ):
1315 from argparse import BooleanOptionalAction
1416from argparse import SUPPRESS , ArgumentParser , Namespace , RawDescriptionHelpFormatter , _SubParsersAction
3537 overload ,
3638)
3739
38- import typing_extensions
3940from dotenv import dotenv_values
4041from pydantic import AliasChoices , AliasPath , BaseModel , Json , RootModel , Secret , TypeAdapter
4142from pydantic ._internal ._repr import Representation
42- from pydantic ._internal ._typing_extra import WithArgsTypes , origin_is_union , typing_base
43- from pydantic ._internal ._utils import deep_update , is_model_class , lenient_issubclass
43+ from pydantic ._internal ._utils import deep_update , is_model_class
4444from pydantic .dataclasses import is_pydantic_dataclass
4545from pydantic .fields import FieldInfo
4646from pydantic_core import PydanticUndefined
47- from typing_extensions import Annotated , _AnnotatedAlias , get_args , get_origin
47+ from typing_extensions import Annotated , get_args , get_origin
48+ from typing_inspection import typing_objects
49+ from typing_inspection .introspection import is_union_origin
4850
49- from pydantic_settings .utils import path_type_label
51+ from pydantic_settings .utils import _lenient_issubclass , _WithArgsTypes , path_type_label
5052
5153if TYPE_CHECKING :
5254 if sys .version_info >= (3 , 11 ):
@@ -482,7 +484,7 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s
482484 field_info .append ((v_alias , self ._apply_case_sensitive (v_alias ), False ))
483485
484486 if not v_alias or self .config .get ('populate_by_name' , False ):
485- if origin_is_union (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
487+ if is_union_origin (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
486488 field_info .append ((field_name , self ._apply_case_sensitive (self .env_prefix + field_name ), True ))
487489 else :
488490 field_info .append ((field_name , self ._apply_case_sensitive (self .env_prefix + field_name ), False ))
@@ -528,12 +530,13 @@ class Settings(BaseSettings):
528530 annotation = field .annotation
529531
530532 # If field is Optional, we need to find the actual type
531- args = get_args (annotation )
532- if origin_is_union (get_origin (field .annotation )) and len (args ) == 2 and type (None ) in args :
533- for arg in args :
534- if arg is not None :
535- annotation = arg
536- break
533+ if is_union_origin (get_origin (field .annotation )):
534+ args = get_args (annotation )
535+ if len (args ) == 2 and type (None ) in args :
536+ for arg in args :
537+ if arg is not None :
538+ annotation = arg
539+ break
537540
538541 # This is here to make mypy happy
539542 # Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
@@ -551,7 +554,7 @@ class Settings(BaseSettings):
551554 values [name ] = value
552555 continue
553556
554- if lenient_issubclass (sub_model_field .annotation , BaseModel ) and isinstance (value , dict ):
557+ if _lenient_issubclass (sub_model_field .annotation , BaseModel ) and isinstance (value , dict ):
555558 values [sub_model_field_name ] = self ._replace_field_names_case_insensitively (sub_model_field , value )
556559 else :
557560 values [sub_model_field_name ] = value
@@ -621,7 +624,7 @@ def __call__(self) -> dict[str, Any]:
621624 field_value = None
622625 if (
623626 not self .case_sensitive
624- # and lenient_issubclass (field.annotation, BaseModel)
627+ # and _lenient_issubclass (field.annotation, BaseModel)
625628 and isinstance (field_value , dict )
626629 ):
627630 data [field_key ] = self ._replace_field_names_case_insensitively (field , field_value )
@@ -840,7 +843,7 @@ def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
840843 """
841844 if self .field_is_complex (field ):
842845 allow_parse_failure = False
843- elif origin_is_union (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
846+ elif is_union_origin (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
844847 allow_parse_failure = True
845848 else :
846849 return False , False
@@ -886,12 +889,11 @@ class Cfg(BaseSettings):
886889 return None
887890
888891 annotation = field .annotation if isinstance (field , FieldInfo ) else field
889- if origin_is_union (get_origin (annotation )) or isinstance (annotation , WithArgsTypes ):
890- for type_ in get_args (annotation ):
891- type_has_key = self .next_field (type_ , key , case_sensitive )
892- if type_has_key :
893- return type_has_key
894- elif is_model_class (annotation ) or is_pydantic_dataclass (annotation ):
892+ for type_ in get_args (annotation ):
893+ type_has_key = self .next_field (type_ , key , case_sensitive )
894+ if type_has_key :
895+ return type_has_key
896+ if is_model_class (annotation ) or is_pydantic_dataclass (annotation ):
895897 fields = _get_model_fields (annotation )
896898 # `case_sensitive is None` is here to be compatible with the old behavior.
897899 # Has to be removed in V3.
@@ -921,7 +923,8 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
921923 if not self .env_nested_delimiter :
922924 return {}
923925
924- is_dict = lenient_issubclass (get_origin (field .annotation ), dict )
926+ ann = field .annotation
927+ is_dict = ann is dict or _lenient_issubclass (get_origin (ann ), dict )
925928
926929 prefixes = [
927930 f'{ env_name } { self .env_nested_delimiter } ' for _ , env_name , _ in self ._extract_field_info (field , field_name )
@@ -1063,7 +1066,7 @@ def __call__(self) -> dict[str, Any]:
10631066 (
10641067 _annotation_is_complex (field .annotation , field .metadata )
10651068 or (
1066- origin_is_union (get_origin (field .annotation ))
1069+ is_union_origin (get_origin (field .annotation ))
10671070 and _union_is_complex (field .annotation , field .metadata )
10681071 )
10691072 )
@@ -1380,7 +1383,7 @@ def _get_merge_parsed_list_types(
13801383 merge_type = self ._cli_dict_args .get (field_name , list )
13811384 if (
13821385 merge_type is list
1383- or not origin_is_union (get_origin (merge_type ))
1386+ or not is_union_origin (get_origin (merge_type ))
13841387 or not any (
13851388 type_
13861389 for type_ in get_args (merge_type )
@@ -1512,9 +1515,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str,
15121515
15131516 if field_info .annotation is not bool :
15141517 raise SettingsError (f'{ cli_flag_name } argument { model .__name__ } .{ field_name } is not of type bool' )
1515- elif sys .version_info < (3 , 9 ) and (
1516- field_info .default is PydanticUndefined and field_info .default_factory is None
1517- ):
1518+ elif sys .version_info < (3 , 9 ) and field_info .is_required ():
15181519 raise SettingsError (
15191520 f'{ cli_flag_name } argument { model .__name__ } .{ field_name } must have default for python versions < 3.9'
15201521 )
@@ -1530,7 +1531,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
15301531 alias_names , * _ = _get_alias_names (field_name , field_info )
15311532 if len (alias_names ) > 1 :
15321533 raise SettingsError (f'subcommand argument { model .__name__ } .{ field_name } has multiple aliases' )
1533- field_types = [ type_ for type_ in get_args (field_info .annotation ) if type_ is not type (None )]
1534+ field_types = ( type_ for type_ in get_args (field_info .annotation ) if type_ is not type (None ))
15341535 for field_type in field_types :
15351536 if not (is_model_class (field_type ) or is_pydantic_dataclass (field_type )):
15361537 raise SettingsError (
@@ -1996,19 +1997,20 @@ def _metavar_format_recurse(self, obj: Any) -> str:
19961997 return '...'
19971998 elif isinstance (obj , Representation ):
19981999 return repr (obj )
1999- elif isinstance (obj , typing_extensions . TypeAliasType ):
2000+ elif typing_objects . is_typealiastype (obj ):
20002001 return str (obj )
20012002
2002- if not isinstance (obj , (typing_base , WithArgsTypes , type )):
2003+ origin = get_origin (obj )
2004+ if origin is None and not isinstance (obj , (type , typing .ForwardRef , typing_extensions .ForwardRef )):
20032005 obj = obj .__class__
20042006
2005- if origin_is_union ( get_origin ( obj ) ):
2007+ if is_union_origin ( origin ):
20062008 return self ._metavar_format_choices (list (map (self ._metavar_format_recurse , self ._get_modified_args (obj ))))
2007- elif get_origin ( obj ) in ( typing_extensions . Literal , typing . Literal ):
2009+ elif typing_objects . is_literal ( origin ):
20082010 return self ._metavar_format_choices (list (map (str , self ._get_modified_args (obj ))))
2009- elif lenient_issubclass (obj , Enum ):
2011+ elif _lenient_issubclass (obj , Enum ):
20102012 return self ._metavar_format_choices ([val .name for val in obj ])
2011- elif isinstance (obj , WithArgsTypes ):
2013+ elif isinstance (obj , _WithArgsTypes ):
20122014 return self ._metavar_format_choices (
20132015 list (map (self ._metavar_format_recurse , self ._get_modified_args (obj ))),
20142016 obj_qualname = obj .__qualname__ if hasattr (obj , '__qualname__' ) else str (obj ),
@@ -2304,25 +2306,22 @@ def read_env_file(
23042306def _annotation_is_complex (annotation : type [Any ] | None , metadata : list [Any ]) -> bool :
23052307 # If the model is a root model, the root annotation should be used to
23062308 # evaluate the complexity.
2307- try :
2308- if annotation is not None and issubclass (annotation , RootModel ):
2309- # In some rare cases (see test_root_model_as_field),
2310- # the root attribute is not available. For these cases, python 3.8 and 3.9
2311- # return 'RootModelRootType'.
2312- root_annotation = annotation .__annotations__ .get ('root' , None )
2313- if root_annotation is not None and root_annotation != 'RootModelRootType' :
2314- annotation = root_annotation
2315- except TypeError :
2316- pass
2309+ if annotation is not None and _lenient_issubclass (annotation , RootModel ) and annotation is not RootModel :
2310+ annotation = cast ('type[RootModel[Any]]' , annotation )
2311+ root_annotation = annotation .model_fields ['root' ].annotation
2312+ if root_annotation is not None :
2313+ annotation = root_annotation
23172314
23182315 if any (isinstance (md , Json ) for md in metadata ): # type: ignore[misc]
23192316 return False
2317+
2318+ origin = get_origin (annotation )
2319+
23202320 # Check if annotation is of the form Annotated[type, metadata].
2321- if isinstance ( annotation , _AnnotatedAlias ):
2321+ if typing_objects . is_annotated ( origin ):
23222322 # Return result of recursive call on inner type.
23232323 inner , * meta = get_args (annotation )
23242324 return _annotation_is_complex (inner , meta )
2325- origin = get_origin (annotation )
23262325
23272326 if origin is Secret :
23282327 return False
@@ -2336,12 +2335,12 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) ->
23362335
23372336
23382337def _annotation_is_complex_inner (annotation : type [Any ] | None ) -> bool :
2339- if lenient_issubclass (annotation , (str , bytes )):
2338+ if _lenient_issubclass (annotation , (str , bytes )):
23402339 return False
23412340
2342- return lenient_issubclass ( annotation , ( BaseModel , Mapping , Sequence , tuple , set , frozenset , deque )) or is_dataclass (
2343- annotation
2344- )
2341+ return _lenient_issubclass (
2342+ annotation , ( BaseModel , Mapping , Sequence , tuple , set , frozenset , deque )
2343+ ) or is_dataclass ( annotation )
23452344
23462345
23472346def _union_is_complex (annotation : type [Any ] | None , metadata : list [Any ]) -> bool :
@@ -2365,22 +2364,23 @@ def _annotation_contains_types(
23652364
23662365
23672366def _strip_annotated (annotation : Any ) -> Any :
2368- while get_origin (annotation ) == Annotated :
2369- annotation = get_args (annotation )[0 ]
2370- return annotation
2367+ if typing_objects .is_annotated (get_origin (annotation )):
2368+ return annotation .__origin__
2369+ else :
2370+ return annotation
23712371
23722372
23732373def _annotation_enum_val_to_name (annotation : type [Any ] | None , value : Any ) -> Optional [str ]:
23742374 for type_ in (annotation , get_origin (annotation ), * get_args (annotation )):
2375- if lenient_issubclass (type_ , Enum ):
2375+ if _lenient_issubclass (type_ , Enum ):
23762376 if value in tuple (val .value for val in type_ ):
23772377 return type_ (value ).name
23782378 return None
23792379
23802380
23812381def _annotation_enum_name_to_val (annotation : type [Any ] | None , name : Any ) -> Any :
23822382 for type_ in (annotation , get_origin (annotation ), * get_args (annotation )):
2383- if lenient_issubclass (type_ , Enum ):
2383+ if _lenient_issubclass (type_ , Enum ):
23842384 if name in tuple (val .name for val in type_ ):
23852385 return type_ [name ]
23862386 return None
0 commit comments