Skip to content

Commit 57bf7e6

Browse files
committed
Fix handling of type-var syntax and types.GenericAlias
1 parent a46a2c6 commit 57bf7e6

File tree

3 files changed

+281
-7
lines changed

3 files changed

+281
-7
lines changed

src/msgspec/_core.c

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ typedef struct {
504504
PyObject *typing_final;
505505
PyObject *typing_generic;
506506
PyObject *typing_generic_alias;
507+
PyObject *types_generic_alias;
507508
PyObject *typing_annotated_alias;
508509
PyObject *concrete_types;
509510
PyObject *get_type_hints;
@@ -4948,6 +4949,24 @@ is_dataclass_or_attrs_class(TypeNodeCollectState *state, PyObject *t) {
49484949
);
49494950
}
49504951

4952+
static MS_INLINE PyObject*
4953+
convert_types_generic_alias(TypeNodeCollectState *state, PyObject *obj, PyObject *origin, PyObject *args) {
4954+
// if 'obj' is a 'types.GenericAlias', convert it into a 'typing._GenericAlias', so
4955+
// we can cache type info on it. 'types.GenericAlias' has __slots__, so caching on
4956+
// it directly does not work.
4957+
// it's unlikely to hit this case, as it will mostly occur when subclassing a
4958+
// built-in container generic, such as 'collections.abc.Mapping'
4959+
4960+
if (MS_UNLIKELY(Py_TYPE(obj) == (PyTypeObject *)state->mod->types_generic_alias)) {
4961+
PyObject *genericAliasArgsList = Py_BuildValue("OO", origin, args);
4962+
4963+
PyObject *newGenericAlias = PyObject_CallObject(state->mod->typing_generic_alias, genericAliasArgsList);
4964+
Py_DECREF(genericAliasArgsList);
4965+
return newGenericAlias;
4966+
}
4967+
return obj;
4968+
}
4969+
49514970
static int
49524971
typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
49534972
int out = 0;
@@ -5029,7 +5048,7 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
50295048
ms_is_struct_cls(t) ||
50305049
(origin != NULL && ms_is_struct_cls(origin))
50315050
) {
5032-
out = typenode_collect_struct(state, t);
5051+
out = typenode_collect_struct(state, convert_types_generic_alias(state, t, origin, args));
50335052
}
50345053
else if (Py_TYPE(t) == state->mod->EnumMetaType) {
50355054
out = typenode_collect_enum(state, t);
@@ -5137,7 +5156,7 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) {
51375156
is_dataclass_or_attrs_class(state, t) ||
51385157
(origin != NULL && is_dataclass_or_attrs_class(state, origin))
51395158
) {
5140-
out = typenode_collect_dataclass(state, t);
5159+
out = typenode_collect_dataclass(state, convert_types_generic_alias(state, t, origin, args));
51415160
}
51425161
else {
51435162
if (origin != NULL) {
@@ -22291,6 +22310,7 @@ msgspec_clear(PyObject *m)
2229122310
Py_CLEAR(st->typing_final);
2229222311
Py_CLEAR(st->typing_generic);
2229322312
Py_CLEAR(st->typing_generic_alias);
22313+
Py_CLEAR(st->types_generic_alias);
2229422314
Py_CLEAR(st->typing_annotated_alias);
2229522315
Py_CLEAR(st->concrete_types);
2229622316
Py_CLEAR(st->get_type_hints);
@@ -22365,6 +22385,7 @@ msgspec_traverse(PyObject *m, visitproc visit, void *arg)
2236522385
Py_VISIT(st->typing_final);
2236622386
Py_VISIT(st->typing_generic);
2236722387
Py_VISIT(st->typing_generic_alias);
22388+
Py_VISIT(st->types_generic_alias);
2236822389
Py_VISIT(st->typing_annotated_alias);
2236922390
Py_VISIT(st->concrete_types);
2237022391
Py_VISIT(st->get_type_hints);
@@ -22598,6 +22619,7 @@ PyInit__core(void)
2259822619
temp_module = PyImport_ImportModule("types");
2259922620
if (temp_module == NULL) return NULL;
2260022621
SET_REF(types_uniontype, "UnionType");
22622+
SET_REF(types_generic_alias, "GenericAlias");
2260122623
Py_DECREF(temp_module);
2260222624
#endif
2260322625

src/msgspec/_utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# type: ignore
22
import collections
33
import sys
4+
import types
45
import typing
56
from typing import _AnnotatedAlias # noqa: F401
67

@@ -22,6 +23,8 @@ def get_type_hints(obj):
2223
return _get_type_hints(obj, include_extras=True)
2324

2425

26+
PY_31PLUS = sys.version_info >= (3, 12)
27+
2528
# The `is_class` argument was new in 3.11, but was backported to 3.9 and 3.10.
2629
# It's _likely_ to be available for 3.9/3.10, but may not be. Easiest way to
2730
# check is to try it and see. This check can be removed when we drop support
@@ -110,13 +113,23 @@ def inner(c, scope):
110113
cls = c
111114
new_scope = {}
112115
else:
113-
cls = getattr(c, "__origin__", None)
116+
cls = typing.get_origin(c)
114117
if cls in (None, object, typing.Generic) or cls in mapping:
115118
return
116-
params = cls.__parameters__
117-
args = tuple(_apply_params(a, scope) for a in c.__args__)
118-
assert len(params) == len(args)
119-
mapping[cls] = new_scope = dict(zip(params, args))
119+
120+
# it's a built-in generic that has unresolved type vars. in this case,
121+
# parameters and args are stored on the generic, not the __origin__
122+
if isinstance(c, types.GenericAlias) or (
123+
isinstance(c, typing._GenericAlias)
124+
and not hasattr(cls, "__parameters__")
125+
):
126+
new_scope = dict(zip(c.__parameters__, typing.get_args(c)))
127+
else:
128+
params = cls.__parameters__
129+
args = tuple(_apply_params(a, scope) for a in typing.get_args(c))
130+
assert len(params) == len(args)
131+
new_scope = dict(zip(params, args))
132+
mapping[cls] = new_scope
120133

121134
if issubclass(cls, typing.Generic):
122135
bases = getattr(cls, "__orig_bases__", cls.__bases__)
@@ -154,6 +167,11 @@ def get_class_annotations(obj):
154167

155168
mapping = typevar_mappings.get(cls)
156169
cls_locals = dict(vars(cls))
170+
171+
if PY_31PLUS:
172+
# resolve type parameters (e.g. class Foo[T]: pass)
173+
cls_locals.update({p.__name__: p for p in cls.__type_params__})
174+
157175
cls_globals = getattr(sys.modules.get(cls.__module__, None), "__dict__", {})
158176

159177
ann = _get_class_annotations(cls)

tests/unit/test_common.py

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,6 +1425,35 @@ class Ex(Struct, Generic[T], array_like=array_like):
14251425
with pytest.raises(ValidationError, match="Expected `str`, got `int`"):
14261426
proto.decode(msg, type=Ex[str])
14271427

1428+
@py312_plus
1429+
def test_generic_with_typevar_syntax(self, proto):
1430+
source = """
1431+
from msgspec import Struct
1432+
from typing import List
1433+
class Ex[T](Struct):
1434+
x: T
1435+
y: List[T]
1436+
"""
1437+
1438+
with temp_module(source) as mod:
1439+
sol = mod.Ex(1, [1, 2])
1440+
msg = proto.encode(sol)
1441+
1442+
res = proto.decode(msg, type=mod.Ex)
1443+
assert res == sol
1444+
1445+
res = proto.decode(msg, type=mod.Ex[int])
1446+
assert res == sol
1447+
1448+
res = proto.decode(msg, type=mod.Ex[Union[int, str]])
1449+
assert res == sol
1450+
1451+
res = proto.decode(msg, type=mod.Ex[float])
1452+
assert type(res.x) is float
1453+
1454+
with pytest.raises(ValidationError, match="Expected `str`, got `int`"):
1455+
proto.decode(msg, type=mod.Ex[str])
1456+
14281457
@pytest.mark.parametrize("array_like", [False, True])
14291458
def test_recursive_generic_struct(self, proto, array_like):
14301459
source = f"""
@@ -1516,6 +1545,95 @@ def test_unbound_typevars_with_constraints_unsupported(self, proto):
15161545

15171546
assert "Unbound TypeVar `~T` has constraints" in str(rec.value)
15181547

1548+
@pytest.mark.parametrize(
1549+
"future",
1550+
[pytest.param(False, id="no future"), pytest.param(False, id="future")],
1551+
)
1552+
@pytest.mark.parametrize(
1553+
"mapping_type", ["collections.abc.Mapping", "typing.Mapping"]
1554+
)
1555+
def test_inherited_builtin_generic(self, mapping_type: str, future: bool):
1556+
source = f"""
1557+
from msgspec import Struct, StructMeta
1558+
import collections
1559+
import abc
1560+
import typing
1561+
1562+
T = typing.TypeVar("T")
1563+
1564+
class CombinedMeta(StructMeta, abc.ABCMeta):
1565+
pass
1566+
1567+
class Foo({mapping_type}[str, T], Struct, typing.Generic[T], metaclass=CombinedMeta):
1568+
data: dict[str, T]
1569+
1570+
def __getitem__(self, x):
1571+
return self.data[x]
1572+
1573+
def __len__(self):
1574+
return len(self.data)
1575+
1576+
def __iter__(self):
1577+
return iter(self.data)
1578+
"""
1579+
1580+
if future:
1581+
source = "from __future__ import annotations\n" + source
1582+
1583+
with temp_module(source) as mod, pytest.raises(ValidationError):
1584+
msgspec.msgpack.decode(
1585+
msgspec.msgpack.encode(mod.Foo({"x": "foo"})), type=mod.Foo[int]
1586+
)
1587+
1588+
msgspec.msgpack.decode(
1589+
msgspec.msgpack.encode(mod.Foo({"x": 1})), type=mod.Foo[int]
1590+
)
1591+
1592+
@pytest.mark.parametrize(
1593+
"future",
1594+
[pytest.param(False, id="no future"), pytest.param(False, id="future")],
1595+
)
1596+
@pytest.mark.parametrize(
1597+
"mapping_type", ["collections.abc.Mapping", "typing.Mapping"]
1598+
)
1599+
@py312_plus
1600+
def test_inherited_builtin_generic_typevar_syntax(
1601+
self, mapping_type: str, future: bool
1602+
):
1603+
source = f"""
1604+
from msgspec import Struct, StructMeta
1605+
import collections
1606+
import abc
1607+
import typing
1608+
1609+
class CombinedMeta(StructMeta, abc.ABCMeta):
1610+
pass
1611+
1612+
class Foo[T]({mapping_type}[str, T], Struct, metaclass=CombinedMeta):
1613+
data: dict[str, T]
1614+
1615+
def __getitem__(self, x):
1616+
return self.data[x]
1617+
1618+
def __len__(self):
1619+
return len(self.data)
1620+
1621+
def __iter__(self):
1622+
return iter(self.data)
1623+
"""
1624+
1625+
if future:
1626+
source = "from __future__ import annotations\n" + source
1627+
1628+
with temp_module(source) as mod, pytest.raises(ValidationError):
1629+
msgspec.msgpack.decode(
1630+
msgspec.msgpack.encode(mod.Foo({"x": "foo"})), type=mod.Foo[int]
1631+
)
1632+
1633+
msgspec.msgpack.decode(
1634+
msgspec.msgpack.encode(mod.Foo({"x": 1})), type=mod.Foo[int]
1635+
)
1636+
15191637

15201638
class TestStructPostInit:
15211639
@pytest.mark.parametrize("array_like", [False, True])
@@ -1697,6 +1815,39 @@ class Ex(Generic[T]):
16971815
assert "`$.b.a`" in str(rec.value)
16981816
assert "Expected `int`, got `str`" in str(rec.value)
16991817

1818+
@pytest.mark.parametrize("module", ["dataclasses", "attrs"])
1819+
@py312_plus
1820+
def test_typevar_syntax(self, module, proto):
1821+
pytest.importorskip(module)
1822+
if module == "dataclasses":
1823+
import_ = "from dataclasses import dataclass as decorator"
1824+
else:
1825+
import_ = "from attrs import define as decorator"
1826+
1827+
source = f"""
1828+
from __future__ import annotations
1829+
from typing import Union
1830+
from msgspec import Struct
1831+
{import_}
1832+
1833+
@decorator
1834+
class Ex[T]:
1835+
a: T
1836+
b: Union[Ex[T], None]
1837+
"""
1838+
1839+
with temp_module(source) as mod:
1840+
msg = mod.Ex(a=1, b=mod.Ex(a=2, b=None))
1841+
msg2 = mod.Ex(a=1, b=mod.Ex(a="bad", b=None))
1842+
assert proto.decode(proto.encode(msg), type=mod.Ex) == msg
1843+
assert proto.decode(proto.encode(msg2), type=mod.Ex) == msg2
1844+
assert proto.decode(proto.encode(msg), type=mod.Ex[int]) == msg
1845+
1846+
with pytest.raises(ValidationError) as rec:
1847+
proto.decode(proto.encode(msg2), type=mod.Ex[int])
1848+
assert "`$.b.a`" in str(rec.value)
1849+
assert "Expected `int`, got `str`" in str(rec.value)
1850+
17001851
def test_unbound_typevars_use_bound_if_set(self, proto):
17011852
T = TypeVar("T", bound=Union[int, str])
17021853

@@ -1719,6 +1870,89 @@ def test_unbound_typevars_with_constraints_unsupported(self, proto):
17191870

17201871
assert "Unbound TypeVar `~T` has constraints" in str(rec.value)
17211872

1873+
@pytest.mark.parametrize(
1874+
"future",
1875+
[pytest.param(False, id="no future"), pytest.param(False, id="future")],
1876+
)
1877+
@pytest.mark.parametrize(
1878+
"mapping_type", ["collections.abc.Mapping", "typing.Mapping"]
1879+
)
1880+
def test_inherited_builtin_generic(self, mapping_type: str, future: bool):
1881+
source = f"""
1882+
import typing
1883+
import dataclasses
1884+
import collections
1885+
1886+
T = typing.TypeVar("T")
1887+
1888+
@dataclasses.dataclass
1889+
class Foo(typing.Generic[T], {mapping_type}[str, T]):
1890+
data: dict[str, T]
1891+
1892+
def __getitem__(self, x):
1893+
return self.data[x]
1894+
1895+
def __len__(self):
1896+
return len(self.data)
1897+
1898+
def __iter__(self):
1899+
return iter(self.data)
1900+
"""
1901+
1902+
if future:
1903+
source = "from __future__ import annotations\n" + source
1904+
1905+
with temp_module(source) as mod, pytest.raises(ValidationError):
1906+
msgspec.msgpack.decode(
1907+
msgspec.msgpack.encode(mod.Foo({"x": "foo"})), type=mod.Foo[int]
1908+
)
1909+
1910+
msgspec.msgpack.decode(
1911+
msgspec.msgpack.encode(mod.Foo({"x": 1})), type=mod.Foo[int]
1912+
)
1913+
1914+
@pytest.mark.parametrize(
1915+
"future",
1916+
[pytest.param(False, id="no future"), pytest.param(False, id="future")],
1917+
)
1918+
@pytest.mark.parametrize(
1919+
"mapping_type", ["collections.abc.Mapping", "typing.Mapping"]
1920+
)
1921+
@py312_plus
1922+
def test_inherited_builtin_generic_typevar_syntax(
1923+
self, mapping_type: str, future: bool
1924+
):
1925+
source = f"""
1926+
import dataclasses
1927+
import collections
1928+
import typing
1929+
1930+
@dataclasses.dataclass
1931+
class Foo[T]({mapping_type}[str, T]):
1932+
data: dict[str, T]
1933+
1934+
def __getitem__(self, x):
1935+
return self.data[x]
1936+
1937+
def __len__(self):
1938+
return len(self.data)
1939+
1940+
def __iter__(self):
1941+
return iter(self.data)
1942+
"""
1943+
1944+
if future:
1945+
source = "from __future__ import annotations\n" + source
1946+
1947+
with temp_module(source) as mod, pytest.raises(ValidationError):
1948+
msgspec.msgpack.decode(
1949+
msgspec.msgpack.encode(mod.Foo({"x": "foo"})), type=mod.Foo[int]
1950+
)
1951+
1952+
msgspec.msgpack.decode(
1953+
msgspec.msgpack.encode(mod.Foo({"x": 1})), type=mod.Foo[int]
1954+
)
1955+
17221956

17231957
class TestStructOmitDefaults:
17241958
def test_omit_defaults(self, proto):

0 commit comments

Comments
 (0)