Skip to content

Commit 0ebda01

Browse files
author
Diego Argueta
committed
Add support for type-variant generics
1 parent 4f5b3f1 commit 0ebda01

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

src/desert/_make.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class User:
5454
Schema: t.ClassVar[Type[Schema]] = Schema # For the type checker
5555
"""
5656

57+
import collections.abc as coll_abc
5758
import dataclasses
5859
import datetime
5960
import decimal
@@ -244,7 +245,7 @@ def field_for_schema(
244245
if origin:
245246
arguments = typing_inspect.get_args(typ, True)
246247

247-
if origin in (list, t.List):
248+
if origin in (list, t.List, coll_abc.Sequence, coll_abc.MutableSequence):
248249
field = marshmallow.fields.List(field_for_schema(arguments[0]))
249250

250251
if origin in (tuple, t.Tuple) and Ellipsis not in arguments:
@@ -256,7 +257,7 @@ def field_for_schema(
256257
field = VariadicTuple(
257258
field_for_schema(only(arg for arg in arguments if arg != Ellipsis))
258259
)
259-
elif origin in (dict, t.Dict):
260+
elif origin in (dict, t.Dict, coll_abc.Mapping, coll_abc.MutableMapping):
260261
field = marshmallow.fields.Dict(
261262
keys=field_for_schema(arguments[0]),
262263
values=field_for_schema(arguments[1]),

tests/test_make.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -163,24 +163,22 @@ class A:
163163
assert data == A(1) # type: ignore[call-arg]
164164

165165

166-
def test_list(module: DataclassModule) -> None:
166+
@pytest.mark.parametrize("annotation_class", (t.List, t.Sequence, t.MutableSequence))
167+
def test_list(module: DataclassModule, annotation_class: type) -> None:
167168
"""Build a generic list *without* setting a factory on the dataclass."""
168-
169-
@module.dataclass
170-
class A:
171-
y: t.List[int]
169+
klass = type("A", (object,), {"__annotations__": {"y": annotation_class[int]}})
170+
A = module.dataclass(klass)
172171

173172
schema = desert.schema_class(A)()
174173
data = schema.load({"y": [1]})
175174
assert data == A([1]) # type: ignore[call-arg]
176175

177176

178-
def test_dict(module: DataclassModule) -> None:
177+
@pytest.mark.parametrize("annotation_class", (t.Dict, t.Mapping, t.MutableMapping))
178+
def test_dict(module: DataclassModule, annotation_class: type) -> None:
179179
"""Build a dict without setting a factory on the dataclass."""
180-
181-
@module.dataclass
182-
class A:
183-
y: t.Dict[int, int]
180+
klass = type("A", (object,), {"__annotations__": {"y": annotation_class[int, int]}})
181+
A = module.dataclass(klass)
184182

185183
schema = desert.schema_class(A)()
186184
data = schema.load({"y": {1: 2, 3: 4}})
@@ -527,15 +525,12 @@ class A:
527525
desert.schema_class(A)
528526

529527

530-
@pytest.mark.skipif(
531-
sys.version_info[:2] <= (3, 6), reason="3.6 has isinstance(t.Sequence[int], type)."
532-
)
533528
def test_raise_unknown_generic(module: DataclassModule) -> None:
534529
"""Raise UnknownType for unknown generics."""
535530

536531
@module.dataclass
537532
class A:
538-
x: t.Sequence[int]
533+
x: t.Iterable[int]
539534

540535
with pytest.raises(desert.exceptions.UnknownType):
541536
desert.schema_class(A)

0 commit comments

Comments
 (0)