Skip to content

Commit 8c5a45e

Browse files
authored
Fix an issue when inner types of a discriminated union with a callable discriminator were not correctly identified as complex. (#285)
1 parent 09d1009 commit 8c5a45e

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

pydantic_settings/sources.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union
1717
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
1818
from pydantic.fields import FieldInfo
19-
from typing_extensions import get_args, get_origin
19+
from typing_extensions import _AnnotatedAlias, get_args, get_origin
2020

2121
from pydantic_settings.utils import path_type_label
2222

@@ -913,6 +913,11 @@ def read_env_file(
913913
def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
914914
if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
915915
return False
916+
# Check if annotation is of the form Annotated[type, metadata].
917+
if isinstance(annotation, _AnnotatedAlias):
918+
# Return result of recursive call on inner type.
919+
inner, meta = get_args(annotation)
920+
return _annotation_is_complex(inner, [meta])
916921
origin = get_origin(annotation)
917922
return (
918923
_annotation_is_complex_inner(annotation)

tests/test_settings.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,29 @@
66
from datetime import datetime, timezone
77
from enum import IntEnum
88
from pathlib import Path
9-
from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union
9+
from typing import Any, Callable, Dict, Generic, Hashable, List, Optional, Set, Tuple, Type, TypeVar, Union
1010

1111
import pytest
1212
from annotated_types import MinLen
1313
from pydantic import (
1414
AliasChoices,
1515
AliasPath,
1616
BaseModel,
17+
Discriminator,
1718
Field,
1819
HttpUrl,
1920
Json,
2021
RootModel,
2122
SecretStr,
23+
Tag,
2224
ValidationError,
2325
)
2426
from pydantic import (
2527
dataclasses as pydantic_dataclasses,
2628
)
2729
from pydantic.fields import FieldInfo
2830
from pytest_mock import MockerFixture
29-
from typing_extensions import Annotated
31+
from typing_extensions import Annotated, Literal
3032

3133
from pydantic_settings import (
3234
BaseSettings,
@@ -1674,6 +1676,44 @@ class Cfg(BaseSettings):
16741676
Cfg()
16751677

16761678

1679+
def test_discriminated_union_with_callable_discriminator(env):
1680+
class A(BaseModel):
1681+
x: Literal['a'] = 'a'
1682+
y: str
1683+
1684+
class B(BaseModel):
1685+
x: Literal['b'] = 'b'
1686+
z: str
1687+
1688+
def get_discriminator_value(v: Any) -> Hashable:
1689+
if isinstance(v, dict):
1690+
v0 = v.get('x')
1691+
else:
1692+
v0 = getattr(v, 'x', None)
1693+
1694+
if v0 == 'a':
1695+
return 'a'
1696+
elif v0 == 'b':
1697+
return 'b'
1698+
else:
1699+
return None
1700+
1701+
class Settings(BaseSettings):
1702+
model_config = SettingsConfigDict(env_nested_delimiter='__')
1703+
1704+
# Discriminated union using a callable discriminator.
1705+
a_or_b: Annotated[Union[Annotated[A, Tag('a')], Annotated[B, Tag('b')]], Discriminator(get_discriminator_value)]
1706+
1707+
# Set up environment so that the discriminator is 'a'.
1708+
env.set('a_or_b__x', 'a')
1709+
env.set('a_or_b__y', 'foo')
1710+
1711+
s = Settings()
1712+
1713+
assert s.a_or_b.x == 'a'
1714+
assert s.a_or_b.y == 'foo'
1715+
1716+
16771717
def test_nested_model_case_insensitive(env):
16781718
class SubSubSub(BaseModel):
16791719
VaL3: str

0 commit comments

Comments
 (0)