Skip to content

Commit 59519a1

Browse files
committed
Support TypedDict field as Dict[str, Any]
1 parent 0c43ff5 commit 59519a1

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

changelog.d/237.change.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `TypedDict` subclass support to fields. These are treated the same as `Dict[str, Any]`.

src/desert/_make.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class User:
6060
import decimal
6161
import enum
6262
import inspect
63+
import sys
6364
import typing as t
6465
import uuid
6566

@@ -305,11 +306,18 @@ def field_for_schema(
305306
field = field_for_schema(newtype_supertype, default=default)
306307

307308
# enumerations
308-
if type(typ) is enum.EnumMeta:
309+
elif type(typ) is enum.EnumMeta:
309310
import marshmallow_enum
310311

311312
field = marshmallow_enum.EnumField(typ, metadata=metadata)
312313

314+
# TypedDict
315+
elif _is_typeddict(typ):
316+
field = marshmallow.fields.Dict(
317+
keys=marshmallow.fields.String,
318+
values=marshmallow.fields.Raw,
319+
)
320+
313321
# Nested dataclasses
314322
forward_reference = getattr(typ, "__forward_arg__", None)
315323

@@ -370,6 +378,21 @@ def _get_field_default(
370378
raise TypeError(field)
371379

372380

381+
def _is_typeddict(typ: t.Any) -> bool:
382+
if typing_inspect.typed_dict_keys(typ) is not None:
383+
return True
384+
385+
# typing_inspect misses some case.
386+
if sys.version_info >= (3, 10):
387+
return t.is_typeddict(typ)
388+
389+
# python>=3.8; <3.10: Reimplement t.is_typeddict
390+
if sys.version_info >= (3, 8):
391+
return isinstance(typ, t._TypedDictMeta)
392+
393+
return False
394+
395+
373396
@attr.frozen
374397
class _DesertSentinel:
375398
pass

tests/test_make.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
import desert
1919

2020

21+
typed_dict_classes: t.List[t.Any] = [typing_extensions.TypedDict]
22+
if sys.version_info >= (3, 8):
23+
typed_dict_classes.append(t.TypedDict)
24+
25+
2126
@attr.frozen(order=False)
2227
class DataclassModule:
2328
"""Implementation of a dataclass module like attr or dataclasses."""
@@ -45,6 +50,13 @@ def dataclass_param(request: _pytest.fixtures.SubRequest) -> DataclassModule:
4550
return module
4651

4752

53+
@pytest.fixture(
54+
params=typed_dict_classes, ids=[x.__module__ for x in typed_dict_classes]
55+
)
56+
def typed_dict_class(request: _pytest.fixtures.SubRequest) -> t.Any:
57+
return request.param
58+
59+
4860
class AssertLoadDumpProtocol(typing_extensions.Protocol):
4961
def __call__(
5062
self, schema: marshmallow.Schema, loaded: t.Any, dumped: t.Dict[t.Any, t.Any]
@@ -437,6 +449,27 @@ class A:
437449
assert_dump_load(schema=schema, loaded=loaded, dumped=dumped)
438450

439451

452+
def test_typed_dict(
453+
module: DataclassModule,
454+
assert_dump_load: AssertLoadDumpProtocol,
455+
typed_dict_class: t.Type[t.Any],
456+
) -> None:
457+
"""Test dataclasses with basic TypedDict support"""
458+
459+
class B(typed_dict_class): # type: ignore[valid-type, misc]
460+
x: int
461+
462+
@module.dataclass
463+
class A:
464+
x: B
465+
466+
schema = desert.schema_class(A)()
467+
dumped = {"x": {"x": 1}}
468+
loaded = A(x={"x": 1}) # type: ignore[call-arg]
469+
470+
assert_dump_load(schema=schema, loaded=loaded, dumped=dumped)
471+
472+
440473
@pytest.mark.xfail(
441474
strict=True,
442475
reason=(

0 commit comments

Comments
 (0)