|
11 | 11 |
|
12 | 12 | from dotenv import dotenv_values
|
13 | 13 | from pydantic import AliasChoices, AliasPath, BaseModel, Json, TypeAdapter
|
14 |
| -from pydantic._internal._typing_extra import origin_is_union |
15 |
| -from pydantic._internal._utils import deep_update, lenient_issubclass |
| 14 | +from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union |
| 15 | +from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass |
16 | 16 | from pydantic.fields import FieldInfo
|
17 | 17 | from typing_extensions import get_args, get_origin
|
18 | 18 |
|
@@ -188,6 +188,8 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s
|
188 | 188 | )
|
189 | 189 | else: # string validation alias
|
190 | 190 | field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
|
| 191 | + elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata): |
| 192 | + field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True)) |
191 | 193 | else:
|
192 | 194 | field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))
|
193 | 195 |
|
@@ -478,24 +480,21 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val
|
478 | 480 | # simplest case, field is not complex, we only need to add the value if it was found
|
479 | 481 | return value
|
480 | 482 |
|
481 |
| - def _union_is_complex(self, annotation: type[Any] | None, metadata: list[Any]) -> bool: |
482 |
| - return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation)) |
483 |
| - |
484 | 483 | def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
|
485 | 484 | """
|
486 | 485 | Find out if a field is complex, and if so whether JSON errors should be ignored
|
487 | 486 | """
|
488 | 487 | if self.field_is_complex(field):
|
489 | 488 | allow_parse_failure = False
|
490 |
| - elif origin_is_union(get_origin(field.annotation)) and self._union_is_complex(field.annotation, field.metadata): |
| 489 | + elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata): |
491 | 490 | allow_parse_failure = True
|
492 | 491 | else:
|
493 | 492 | return False, False
|
494 | 493 |
|
495 | 494 | return True, allow_parse_failure
|
496 | 495 |
|
497 | 496 | @staticmethod
|
498 |
| - def next_field(field: FieldInfo | None, key: str) -> FieldInfo | None: |
| 497 | + def next_field(field: FieldInfo | Any | None, key: str) -> FieldInfo | None: |
499 | 498 | """
|
500 | 499 | Find the field in a sub model by key(env name)
|
501 | 500 |
|
@@ -524,11 +523,17 @@ class Cfg(BaseSettings):
|
524 | 523 | Returns:
|
525 | 524 | Field if it finds the next field otherwise `None`.
|
526 | 525 | """
|
527 |
| - if not field or origin_is_union(get_origin(field.annotation)): |
528 |
| - # no support for Unions of complex BaseSettings fields |
| 526 | + if not field: |
529 | 527 | return None
|
530 |
| - elif field.annotation and hasattr(field.annotation, 'model_fields') and field.annotation.model_fields.get(key): |
531 |
| - return field.annotation.model_fields[key] |
| 528 | + |
| 529 | + annotation = field.annotation if isinstance(field, FieldInfo) else field |
| 530 | + if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes): |
| 531 | + for type_ in get_args(annotation): |
| 532 | + type_has_key = EnvSettingsSource.next_field(type_, key) |
| 533 | + if type_has_key: |
| 534 | + return type_has_key |
| 535 | + elif is_model_class(annotation) and annotation.model_fields.get(key): |
| 536 | + return annotation.model_fields[key] |
532 | 537 |
|
533 | 538 | return None
|
534 | 539 |
|
@@ -721,3 +726,7 @@ def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
|
721 | 726 | return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass(
|
722 | 727 | annotation
|
723 | 728 | )
|
| 729 | + |
| 730 | + |
| 731 | +def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool: |
| 732 | + return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation)) |
0 commit comments