Skip to content

Commit c69bd74

Browse files
committed
@robotoer's changes to support literal types
1 parent fa99d64 commit c69bd74

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

graphene_pydantic/converters.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,10 @@ def convert_generic_python_type(
266266
return convert_union_type(
267267
type_, field, registry, parent_type=parent_type, model=model
268268
)
269+
elif origin == T.Literal:
270+
return convert_literal_type(
271+
type_, field, registry, parent_type=parent_type, model=model
272+
)
269273
elif (
270274
origin
271275
in (
@@ -332,3 +336,37 @@ def convert_union_type(
332336
construct_union_class_name(inner_types), (Union,), {"Meta": internal_meta_cls}
333337
)
334338
return union_cls
339+
340+
def convert_literal_type(
341+
type_: T.Type,
342+
field: ModelField,
343+
registry: Registry,
344+
parent_type: T.Type = None,
345+
model: T.Type[BaseModel] = None,
346+
):
347+
"""
348+
Convert an annotated Python Literal type into a Graphene Scalar or Union of Scalars.
349+
"""
350+
inner_types = type_.__args__
351+
# Here we'll expand the subtypes of this Literal into a corresponding more
352+
# general scalar type.
353+
scalar_types = {
354+
type(x)
355+
for x in inner_types
356+
if x != NONE_TYPE
357+
}
358+
graphene_scalar_types = [
359+
convert_pydantic_type(x, field, registry, parent_type=parent_type, model=model)
360+
for x in scalar_types
361+
]
362+
363+
# If we only have a single type, we don't need to create a union.
364+
if len(graphene_scalar_types) == 1:
365+
return graphene_scalar_types[0]
366+
367+
internal_meta_cls = type("Meta", (), {"types": graphene_scalar_types})
368+
369+
union_cls = type(
370+
construct_union_class_name(scalar_types), (Union,), {"Meta": internal_meta_cls}
371+
)
372+
return union_cls

tests/test_converters.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,19 @@ def test_union():
7474
assert field.default_value == 5.0
7575
assert field.type.__name__.startswith("UnionOf")
7676

77+
def test_literal():
78+
field = _convert_field_from_spec("attr", (T.Literal['literal1', 'literal2', 3], 3))
79+
assert issubclass(field.type, graphene.Union)
80+
assert field.default_value == 3
81+
assert field.type.__name__.startswith("UnionOf")
82+
83+
84+
def test_literal_singleton():
85+
field = _convert_field_from_spec("attr", (T.Literal['literal1'], 'literal1'))
86+
assert issubclass(field.type, graphene.String)
87+
assert field.default_value == 'literal1'
88+
assert field.type == graphene.String
89+
7790

7891
def test_mapping():
7992
with pytest.raises(ConversionError) as exc:

0 commit comments

Comments
 (0)