Skip to content

Commit f786412

Browse files
authored
Fix pydantic discriminated unions handled incorrectly (#667)
1 parent 2a4454a commit f786412

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ Fixed
2525
^^^^^
2626
- Incorrect instantiation order when instantiation targets share a parent (`#662
2727
<https://github.com/omni-us/jsonargparse/pull/662>`__).
28+
- Pydantic discriminated unions handled incorrectly (`#667
29+
<https://github.com/omni-us/jsonargparse/pull/667>`__).
2830

2931

3032
v4.36.0 (2025-01-17)

jsonargparse/_optionals.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from contextlib import contextmanager
66
from importlib.metadata import version
77
from importlib.util import find_spec
8-
from typing import Optional
8+
from typing import Optional, Union
99

1010
__all__ = [
1111
"get_config_read_mode",
@@ -348,10 +348,13 @@ def get_module(value):
348348

349349

350350
def is_annotated_validator(typehint: type) -> bool:
351+
from ._util import get_typehint_origin
352+
351353
return (
352354
pydantic_support > 1
353355
and is_annotated(typehint)
354356
and any(get_module(m) in {"pydantic", "annotated_types"} for m in typehint.__metadata__) # type: ignore[attr-defined]
357+
and get_typehint_origin(typehint.__origin__) != Union # type: ignore[attr-defined]
355358
)
356359

357360

jsonargparse_tests/test_dataclass_like.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import dataclasses
44
import json
5-
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
5+
from typing import Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union
66
from unittest.mock import patch
77

88
import pytest
@@ -713,6 +713,18 @@ def test_pydantic_annotated_nested_annotated_dataclass_with_default_factory(pars
713713
cfg = parser.parse_args(["--n", "{}"])
714714
assert cfg.n == Namespace(a1=Namespace(a2=1))
715715

716+
class PingTask(pydantic.BaseModel):
717+
type: Literal["ping"] = "ping"
718+
attr: str = ""
719+
720+
class PongTask(pydantic.BaseModel):
721+
type: Literal["pong"] = "pong"
722+
723+
PingPongTask = annotated[
724+
Union[PingTask, PongTask],
725+
pydantic.Field(discriminator="type"),
726+
]
727+
716728

717729
length = "length"
718730
if pydantic_support:
@@ -806,6 +818,8 @@ def test_subclass(self, parser):
806818
parser.add_argument("--model", type=PydanticSubModel, default=PydanticSubModel(p1="a"))
807819
cfg = parser.parse_args(["--model.p3=0.2"])
808820
assert Namespace(p1="a", p2=3, p3=0.2) == cfg.model
821+
init = parser.instantiate_classes(cfg)
822+
assert isinstance(init.model, PydanticSubModel)
809823

810824
def test_field_default_factory(self, parser):
811825
parser.add_argument("--model", type=PydanticFieldFactory)
@@ -831,6 +845,18 @@ def test_annotated_field(self, parser):
831845
parser.parse_args(["--model.p1=0"])
832846
ctx.match("model.p1")
833847

848+
@pytest.mark.skipif(not (annotated and pydantic_support > 1), reason="Annotated is required")
849+
def test_field_union_discriminator_dot_syntax(self, parser):
850+
parser.add_argument("--model", type=PingPongTask)
851+
cfg = parser.parse_args(["--model.type=pong"])
852+
assert cfg.model == Namespace(type="pong")
853+
init = parser.instantiate_classes(cfg)
854+
assert isinstance(init.model, PongTask)
855+
cfg = parser.parse_args(["--model.type=ping", "--model.attr=abc"])
856+
assert cfg.model == Namespace(type="ping", attr="abc")
857+
init = parser.instantiate_classes(cfg)
858+
assert isinstance(init.model, PingTask)
859+
834860
@pytest.mark.parametrize(
835861
["valid_value", "invalid_value", "cast", "type_str"],
836862
[

0 commit comments

Comments
 (0)