Skip to content

Commit 6c5c54e

Browse files
authored
Merge pull request #164 from dargueta/generic-types
2 parents 4f5b3f1 + 1380be9 commit 6c5c54e

File tree

3 files changed

+65
-15
lines changed

3 files changed

+65
-15
lines changed

changelog.d/140.change.rst

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
It is now possible to use `type-variant generics`_ in your dataclasses, such as ``Sequence``
2+
or ``MutableSequence`` instead of ``List``, ``Mapping`` instead of ``Dict``, etc.
3+
4+
This allows you to hide implementation details from users of your dataclasses. If a field
5+
in your dataclass works just as fine with a tuple as a list, you no longer need to force
6+
your users to pass in a ``list`` just to satisfy type checkers.
7+
8+
For example, by using ``Mapping`` or ``MutableMapping``, users can pass ``OrderedDict`` to
9+
a ``Dict`` attribute without MyPy complaining.
10+
11+
.. code-block:: python
12+
13+
@dataclass
14+
class OldWay:
15+
str_list: List[str]
16+
num_map: Dict[str, float]
17+
18+
19+
# MyPy will reject this even though Marshmallow works just fine. If you use
20+
# type-variant generics, MyPy will accept this code.
21+
instance = OldClass([], collections.ChainMap(MY_DEFAULTS))
22+
23+
24+
@dataclass
25+
class NewWay:
26+
str_list: List[str] # Type-invariants still work
27+
num_map: MutableMapping[str, float] # Now generics do too
28+
29+
30+
.. _type-variant generics: https://mypy.readthedocs.io/en/stable/generics.html

src/desert/_make.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class User:
5454
Schema: t.ClassVar[Type[Schema]] = Schema # For the type checker
5555
"""
5656

57+
import collections.abc
5758
import dataclasses
5859
import datetime
5960
import decimal
@@ -244,7 +245,14 @@ def field_for_schema(
244245
if origin:
245246
arguments = typing_inspect.get_args(typ, True)
246247

247-
if origin in (list, t.List):
248+
if origin in (
249+
list,
250+
t.List,
251+
t.Sequence,
252+
t.MutableSequence,
253+
collections.abc.Sequence,
254+
collections.abc.MutableSequence,
255+
):
248256
field = marshmallow.fields.List(field_for_schema(arguments[0]))
249257

250258
if origin in (tuple, t.Tuple) and Ellipsis not in arguments:
@@ -256,7 +264,14 @@ def field_for_schema(
256264
field = VariadicTuple(
257265
field_for_schema(only(arg for arg in arguments if arg != Ellipsis))
258266
)
259-
elif origin in (dict, t.Dict):
267+
elif origin in (
268+
dict,
269+
t.Dict,
270+
t.Mapping,
271+
t.MutableMapping,
272+
collections.abc.Mapping,
273+
collections.abc.MutableMapping,
274+
):
260275
field = marshmallow.fields.Dict(
261276
keys=field_for_schema(arguments[0]),
262277
values=field_for_schema(arguments[1]),

tests/test_make.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -163,29 +163,27 @@ class A:
163163
assert data == A(1) # type: ignore[call-arg]
164164

165165

166-
def test_list(module: DataclassModule) -> None:
166+
@pytest.mark.parametrize("annotation_class", (t.List, t.Sequence, t.MutableSequence))
167+
def test_list(module: DataclassModule, annotation_class: type) -> None:
167168
"""Build a generic list *without* setting a factory on the dataclass."""
168-
169-
@module.dataclass
170-
class A:
171-
y: t.List[int]
169+
cls = type("A", (object,), {"__annotations__": {"y": annotation_class[int]}}) # type: ignore[index]
170+
A = module.dataclass(cls)
172171

173172
schema = desert.schema_class(A)()
174173
data = schema.load({"y": [1]})
175-
assert data == A([1]) # type: ignore[call-arg]
174+
assert data == A([1])
176175

177176

178-
def test_dict(module: DataclassModule) -> None:
177+
@pytest.mark.parametrize("annotation_class", (t.Dict, t.Mapping, t.MutableMapping))
178+
def test_dict(module: DataclassModule, annotation_class: type) -> None:
179179
"""Build a dict without setting a factory on the dataclass."""
180-
181-
@module.dataclass
182-
class A:
183-
y: t.Dict[int, int]
180+
cls = type("A", (object,), {"__annotations__": {"y": annotation_class[int, int]}}) # type: ignore[index]
181+
A = module.dataclass(cls)
184182

185183
schema = desert.schema_class(A)()
186184
data = schema.load({"y": {1: 2, 3: 4}})
187185

188-
assert data == A({1: 2, 3: 4}) # type: ignore[call-arg]
186+
assert data == A({1: 2, 3: 4})
189187

190188

191189
def test_nested(module: DataclassModule) -> None:
@@ -527,6 +525,13 @@ class A:
527525
desert.schema_class(A)
528526

529527

528+
T = t.TypeVar("T")
529+
530+
531+
class UnknownGeneric(t.Generic[T]):
532+
pass
533+
534+
530535
@pytest.mark.skipif(
531536
sys.version_info[:2] <= (3, 6), reason="3.6 has isinstance(t.Sequence[int], type)."
532537
)
@@ -535,7 +540,7 @@ def test_raise_unknown_generic(module: DataclassModule) -> None:
535540

536541
@module.dataclass
537542
class A:
538-
x: t.Sequence[int]
543+
x: UnknownGeneric[int]
539544

540545
with pytest.raises(desert.exceptions.UnknownType):
541546
desert.schema_class(A)

0 commit comments

Comments
 (0)