|
6 | 6 | from datetime import datetime, timezone
|
7 | 7 | from enum import IntEnum
|
8 | 8 | from pathlib import Path
|
9 |
| -from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union |
| 9 | +from typing import Any, Callable, Dict, Generic, Hashable, List, Optional, Set, Tuple, Type, TypeVar, Union |
10 | 10 |
|
11 | 11 | import pytest
|
12 | 12 | from annotated_types import MinLen
|
13 | 13 | from pydantic import (
|
14 | 14 | AliasChoices,
|
15 | 15 | AliasPath,
|
16 | 16 | BaseModel,
|
| 17 | + Discriminator, |
17 | 18 | Field,
|
18 | 19 | HttpUrl,
|
19 | 20 | Json,
|
20 | 21 | RootModel,
|
21 | 22 | SecretStr,
|
| 23 | + Tag, |
22 | 24 | ValidationError,
|
23 | 25 | )
|
24 | 26 | from pydantic import (
|
25 | 27 | dataclasses as pydantic_dataclasses,
|
26 | 28 | )
|
27 | 29 | from pydantic.fields import FieldInfo
|
28 | 30 | from pytest_mock import MockerFixture
|
29 |
| -from typing_extensions import Annotated |
| 31 | +from typing_extensions import Annotated, Literal |
30 | 32 |
|
31 | 33 | from pydantic_settings import (
|
32 | 34 | BaseSettings,
|
@@ -1674,6 +1676,44 @@ class Cfg(BaseSettings):
|
1674 | 1676 | Cfg()
|
1675 | 1677 |
|
1676 | 1678 |
|
| 1679 | +def test_discriminated_union_with_callable_discriminator(env): |
| 1680 | + class A(BaseModel): |
| 1681 | + x: Literal['a'] = 'a' |
| 1682 | + y: str |
| 1683 | + |
| 1684 | + class B(BaseModel): |
| 1685 | + x: Literal['b'] = 'b' |
| 1686 | + z: str |
| 1687 | + |
| 1688 | + def get_discriminator_value(v: Any) -> Hashable: |
| 1689 | + if isinstance(v, dict): |
| 1690 | + v0 = v.get('x') |
| 1691 | + else: |
| 1692 | + v0 = getattr(v, 'x', None) |
| 1693 | + |
| 1694 | + if v0 == 'a': |
| 1695 | + return 'a' |
| 1696 | + elif v0 == 'b': |
| 1697 | + return 'b' |
| 1698 | + else: |
| 1699 | + return None |
| 1700 | + |
| 1701 | + class Settings(BaseSettings): |
| 1702 | + model_config = SettingsConfigDict(env_nested_delimiter='__') |
| 1703 | + |
| 1704 | + # Discriminated union using a callable discriminator. |
| 1705 | + a_or_b: Annotated[Union[Annotated[A, Tag('a')], Annotated[B, Tag('b')]], Discriminator(get_discriminator_value)] |
| 1706 | + |
| 1707 | + # Set up environment so that the discriminator is 'a'. |
| 1708 | + env.set('a_or_b__x', 'a') |
| 1709 | + env.set('a_or_b__y', 'foo') |
| 1710 | + |
| 1711 | + s = Settings() |
| 1712 | + |
| 1713 | + assert s.a_or_b.x == 'a' |
| 1714 | + assert s.a_or_b.y == 'foo' |
| 1715 | + |
| 1716 | + |
1677 | 1717 | def test_nested_model_case_insensitive(env):
|
1678 | 1718 | class SubSubSub(BaseModel):
|
1679 | 1719 | VaL3: str
|
|
0 commit comments