Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Fixed
^^^^^
- Incorrect instantiation order when instantiation targets share a parent (`#662
<https://github.com/omni-us/jsonargparse/pull/662>`__).
- Pydantic discriminated unions handled incorrectly (`#667
<https://github.com/omni-us/jsonargparse/pull/667>`__).


v4.36.0 (2025-01-17)
Expand Down
5 changes: 4 additions & 1 deletion jsonargparse/_optionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from contextlib import contextmanager
from importlib.metadata import version
from importlib.util import find_spec
from typing import Optional
from typing import Optional, Union

__all__ = [
"get_config_read_mode",
Expand Down Expand Up @@ -348,10 +348,13 @@ def get_module(value):


def is_annotated_validator(typehint: type) -> bool:
from ._util import get_typehint_origin

return (
pydantic_support > 1
and is_annotated(typehint)
and any(get_module(m) in {"pydantic", "annotated_types"} for m in typehint.__metadata__) # type: ignore[attr-defined]
and get_typehint_origin(typehint.__origin__) != Union # type: ignore[attr-defined]
)


Expand Down
28 changes: 27 additions & 1 deletion jsonargparse_tests/test_dataclass_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import dataclasses
import json
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -713,6 +713,18 @@ def test_pydantic_annotated_nested_annotated_dataclass_with_default_factory(pars
cfg = parser.parse_args(["--n", "{}"])
assert cfg.n == Namespace(a1=Namespace(a2=1))

class PingTask(pydantic.BaseModel):
type: Literal["ping"] = "ping"
attr: str = ""

class PongTask(pydantic.BaseModel):
type: Literal["pong"] = "pong"

PingPongTask = annotated[
Union[PingTask, PongTask],
pydantic.Field(discriminator="type"),
]


length = "length"
if pydantic_support:
Expand Down Expand Up @@ -806,6 +818,8 @@ def test_subclass(self, parser):
parser.add_argument("--model", type=PydanticSubModel, default=PydanticSubModel(p1="a"))
cfg = parser.parse_args(["--model.p3=0.2"])
assert Namespace(p1="a", p2=3, p3=0.2) == cfg.model
init = parser.instantiate_classes(cfg)
assert isinstance(init.model, PydanticSubModel)

def test_field_default_factory(self, parser):
parser.add_argument("--model", type=PydanticFieldFactory)
Expand All @@ -831,6 +845,18 @@ def test_annotated_field(self, parser):
parser.parse_args(["--model.p1=0"])
ctx.match("model.p1")

@pytest.mark.skipif(not (annotated and pydantic_support > 1), reason="Annotated is required")
def test_field_union_discriminator_dot_syntax(self, parser):
parser.add_argument("--model", type=PingPongTask)
cfg = parser.parse_args(["--model.type=pong"])
assert cfg.model == Namespace(type="pong")
init = parser.instantiate_classes(cfg)
assert isinstance(init.model, PongTask)
cfg = parser.parse_args(["--model.type=ping", "--model.attr=abc"])
assert cfg.model == Namespace(type="ping", attr="abc")
init = parser.instantiate_classes(cfg)
assert isinstance(init.model, PingTask)

@pytest.mark.parametrize(
["valid_value", "invalid_value", "cast", "type_str"],
[
Expand Down
Loading