Skip to content

Commit 6f5000c

Browse files
committed
fix: optional union types being marked as required
1 parent d60de24 commit 6f5000c

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

graphene_pydantic/converters.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import typing as T
99
import uuid
1010
from types import UnionType
11+
from typing import get_origin
1112

1213
import graphene
1314
from graphene import (
@@ -113,9 +114,14 @@ def convert_pydantic_field(
113114
to the generated Graphene data model type.
114115
"""
115116
declared_type = getattr(field, "annotation", None)
117+
116118
# Convert Python 11 UnionType to T.Union
117-
if isinstance(declared_type, UnionType):
119+
is_union_type = (
120+
get_origin(declared_type) is T.Union or get_origin(declared_type) is UnionType
121+
)
122+
if is_union_type:
118123
declared_type = T.Union[declared_type.__args__]
124+
119125
field_kwargs.setdefault(
120126
"type" if GRAPHENE2 else "type_",
121127
convert_pydantic_type(
@@ -128,6 +134,7 @@ def convert_pydantic_field(
128134
or (
129135
type(field.default) is not PydanticUndefined
130136
and getattr(declared_type, "_name", "") != "Optional"
137+
and not is_union_type
131138
),
132139
)
133140
field_kwargs.setdefault(

tests/test_converters.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55
import typing as T
66
import uuid
7+
from typing import Optional
78

89
import graphene
910
import graphene.types
@@ -70,10 +71,26 @@ def test_builtin_scalars(input, expected):
7071
assert field.default_value == input[1]
7172

7273

73-
def test_union():
74+
def test_union_optional():
7475
field = _convert_field_from_spec("attr", (T.Union[int, float, str], 5.0))
75-
assert issubclass(field.type.of_type, graphene.Union)
76+
assert issubclass(field.type, graphene.Union)
7677
assert field.default_value == 5.0
78+
assert field.type.__name__.startswith("UnionOf")
79+
80+
81+
@pytest.mark.parametrize(
82+
"input",
83+
[
84+
(T.Union[int, float, str], ...),
85+
(T.Union[int, float, str, None], ...),
86+
(Optional[T.Union[int, float, str]], ...),
87+
(Optional[T.Union[int, float, None]], ...),
88+
],
89+
)
90+
def test_union(input):
91+
field = _convert_field_from_spec("attr", input)
92+
assert isinstance(field.type, graphene.NonNull)
93+
assert field.default_value == None
7794
assert field.type.of_type.__name__.startswith("UnionOf")
7895

7996

@@ -95,6 +112,27 @@ def test_literal_singleton():
95112
assert field.default_value == "literal1"
96113
assert field.type.of_type == graphene.String
97114

115+
def test_union_pipe_optional():
116+
field = _convert_field_from_spec("attr", (int | float | str, 5.0))
117+
assert issubclass(field.type, graphene.Union)
118+
assert field.default_value == 5.0
119+
assert field.type.__name__.startswith("UnionOf")
120+
121+
@pytest.mark.parametrize(
122+
"input",
123+
[
124+
(int | float | str, ...),
125+
(int | float | str | None, ...),
126+
(Optional[int | float | str], ...),
127+
(Optional[int | float | None], ...),
128+
],
129+
)
130+
def test_union_pipe(input):
131+
field = _convert_field_from_spec("attr", input)
132+
assert isinstance(field.type, graphene.NonNull)
133+
assert field.default_value == None
134+
assert field.type.of_type.__name__.startswith("UnionOf")
135+
98136

99137
def test_mapping():
100138
with pytest.raises(ConversionError) as exc:

0 commit comments

Comments
 (0)