@@ -1425,6 +1425,35 @@ class Ex(Struct, Generic[T], array_like=array_like):
14251425 with pytest .raises (ValidationError , match = "Expected `str`, got `int`" ):
14261426 proto .decode (msg , type = Ex [str ])
14271427
1428+ @py312_plus
1429+ def test_generic_with_typevar_syntax (self , proto ):
1430+ source = """
1431+ from msgspec import Struct
1432+ from typing import List
1433+ class Ex[T](Struct):
1434+ x: T
1435+ y: List[T]
1436+ """
1437+
1438+ with temp_module (source ) as mod :
1439+ sol = mod .Ex (1 , [1 , 2 ])
1440+ msg = proto .encode (sol )
1441+
1442+ res = proto .decode (msg , type = mod .Ex )
1443+ assert res == sol
1444+
1445+ res = proto .decode (msg , type = mod .Ex [int ])
1446+ assert res == sol
1447+
1448+ res = proto .decode (msg , type = mod .Ex [Union [int , str ]])
1449+ assert res == sol
1450+
1451+ res = proto .decode (msg , type = mod .Ex [float ])
1452+ assert type (res .x ) is float
1453+
1454+ with pytest .raises (ValidationError , match = "Expected `str`, got `int`" ):
1455+ proto .decode (msg , type = mod .Ex [str ])
1456+
14281457 @pytest .mark .parametrize ("array_like" , [False , True ])
14291458 def test_recursive_generic_struct (self , proto , array_like ):
14301459 source = f"""
@@ -1516,6 +1545,95 @@ def test_unbound_typevars_with_constraints_unsupported(self, proto):
15161545
15171546 assert "Unbound TypeVar `~T` has constraints" in str (rec .value )
15181547
1548+ @pytest .mark .parametrize (
1549+ "future" ,
1550+ [pytest .param (False , id = "no future" ), pytest .param (False , id = "future" )],
1551+ )
1552+ @pytest .mark .parametrize (
1553+ "mapping_type" , ["collections.abc.Mapping" , "typing.Mapping" ]
1554+ )
1555+ def test_inherited_builtin_generic (self , mapping_type : str , future : bool ):
1556+ source = f"""
1557+ from msgspec import Struct, StructMeta
1558+ import collections
1559+ import abc
1560+ import typing
1561+
1562+ T = typing.TypeVar("T")
1563+
1564+ class CombinedMeta(StructMeta, abc.ABCMeta):
1565+ pass
1566+
1567+ class Foo({ mapping_type } [str, T], Struct, typing.Generic[T], metaclass=CombinedMeta):
1568+ data: dict[str, T]
1569+
1570+ def __getitem__(self, x):
1571+ return self.data[x]
1572+
1573+ def __len__(self):
1574+ return len(self.data)
1575+
1576+ def __iter__(self):
1577+ return iter(self.data)
1578+ """
1579+
1580+ if future :
1581+ source = "from __future__ import annotations\n " + source
1582+
1583+ with temp_module (source ) as mod , pytest .raises (ValidationError ):
1584+ msgspec .msgpack .decode (
1585+ msgspec .msgpack .encode (mod .Foo ({"x" : "foo" })), type = mod .Foo [int ]
1586+ )
1587+
1588+ msgspec .msgpack .decode (
1589+ msgspec .msgpack .encode (mod .Foo ({"x" : 1 })), type = mod .Foo [int ]
1590+ )
1591+
1592+ @pytest .mark .parametrize (
1593+ "future" ,
1594+ [pytest .param (False , id = "no future" ), pytest .param (False , id = "future" )],
1595+ )
1596+ @pytest .mark .parametrize (
1597+ "mapping_type" , ["collections.abc.Mapping" , "typing.Mapping" ]
1598+ )
1599+ @py312_plus
1600+ def test_inherited_builtin_generic_typevar_syntax (
1601+ self , mapping_type : str , future : bool
1602+ ):
1603+ source = f"""
1604+ from msgspec import Struct, StructMeta
1605+ import collections
1606+ import abc
1607+ import typing
1608+
1609+ class CombinedMeta(StructMeta, abc.ABCMeta):
1610+ pass
1611+
1612+ class Foo[T]({ mapping_type } [str, T], Struct, metaclass=CombinedMeta):
1613+ data: dict[str, T]
1614+
1615+ def __getitem__(self, x):
1616+ return self.data[x]
1617+
1618+ def __len__(self):
1619+ return len(self.data)
1620+
1621+ def __iter__(self):
1622+ return iter(self.data)
1623+ """
1624+
1625+ if future :
1626+ source = "from __future__ import annotations\n " + source
1627+
1628+ with temp_module (source ) as mod , pytest .raises (ValidationError ):
1629+ msgspec .msgpack .decode (
1630+ msgspec .msgpack .encode (mod .Foo ({"x" : "foo" })), type = mod .Foo [int ]
1631+ )
1632+
1633+ msgspec .msgpack .decode (
1634+ msgspec .msgpack .encode (mod .Foo ({"x" : 1 })), type = mod .Foo [int ]
1635+ )
1636+
15191637
15201638class TestStructPostInit :
15211639 @pytest .mark .parametrize ("array_like" , [False , True ])
@@ -1697,6 +1815,39 @@ class Ex(Generic[T]):
16971815 assert "`$.b.a`" in str (rec .value )
16981816 assert "Expected `int`, got `str`" in str (rec .value )
16991817
1818+ @pytest .mark .parametrize ("module" , ["dataclasses" , "attrs" ])
1819+ @py312_plus
1820+ def test_typevar_syntax (self , module , proto ):
1821+ pytest .importorskip (module )
1822+ if module == "dataclasses" :
1823+ import_ = "from dataclasses import dataclass as decorator"
1824+ else :
1825+ import_ = "from attrs import define as decorator"
1826+
1827+ source = f"""
1828+ from __future__ import annotations
1829+ from typing import Union
1830+ from msgspec import Struct
1831+ { import_ }
1832+
1833+ @decorator
1834+ class Ex[T]:
1835+ a: T
1836+ b: Union[Ex[T], None]
1837+ """
1838+
1839+ with temp_module (source ) as mod :
1840+ msg = mod .Ex (a = 1 , b = mod .Ex (a = 2 , b = None ))
1841+ msg2 = mod .Ex (a = 1 , b = mod .Ex (a = "bad" , b = None ))
1842+ assert proto .decode (proto .encode (msg ), type = mod .Ex ) == msg
1843+ assert proto .decode (proto .encode (msg2 ), type = mod .Ex ) == msg2
1844+ assert proto .decode (proto .encode (msg ), type = mod .Ex [int ]) == msg
1845+
1846+ with pytest .raises (ValidationError ) as rec :
1847+ proto .decode (proto .encode (msg2 ), type = mod .Ex [int ])
1848+ assert "`$.b.a`" in str (rec .value )
1849+ assert "Expected `int`, got `str`" in str (rec .value )
1850+
17001851 def test_unbound_typevars_use_bound_if_set (self , proto ):
17011852 T = TypeVar ("T" , bound = Union [int , str ])
17021853
@@ -1719,6 +1870,89 @@ def test_unbound_typevars_with_constraints_unsupported(self, proto):
17191870
17201871 assert "Unbound TypeVar `~T` has constraints" in str (rec .value )
17211872
1873+ @pytest .mark .parametrize (
1874+ "future" ,
1875+ [pytest .param (False , id = "no future" ), pytest .param (False , id = "future" )],
1876+ )
1877+ @pytest .mark .parametrize (
1878+ "mapping_type" , ["collections.abc.Mapping" , "typing.Mapping" ]
1879+ )
1880+ def test_inherited_builtin_generic (self , mapping_type : str , future : bool ):
1881+ source = f"""
1882+ import typing
1883+ import dataclasses
1884+ import collections
1885+
1886+ T = typing.TypeVar("T")
1887+
1888+ @dataclasses.dataclass
1889+ class Foo(typing.Generic[T], { mapping_type } [str, T]):
1890+ data: dict[str, T]
1891+
1892+ def __getitem__(self, x):
1893+ return self.data[x]
1894+
1895+ def __len__(self):
1896+ return len(self.data)
1897+
1898+ def __iter__(self):
1899+ return iter(self.data)
1900+ """
1901+
1902+ if future :
1903+ source = "from __future__ import annotations\n " + source
1904+
1905+ with temp_module (source ) as mod , pytest .raises (ValidationError ):
1906+ msgspec .msgpack .decode (
1907+ msgspec .msgpack .encode (mod .Foo ({"x" : "foo" })), type = mod .Foo [int ]
1908+ )
1909+
1910+ msgspec .msgpack .decode (
1911+ msgspec .msgpack .encode (mod .Foo ({"x" : 1 })), type = mod .Foo [int ]
1912+ )
1913+
1914+ @pytest .mark .parametrize (
1915+ "future" ,
1916+ [pytest .param (False , id = "no future" ), pytest .param (False , id = "future" )],
1917+ )
1918+ @pytest .mark .parametrize (
1919+ "mapping_type" , ["collections.abc.Mapping" , "typing.Mapping" ]
1920+ )
1921+ @py312_plus
1922+ def test_inherited_builtin_generic_typevar_syntax (
1923+ self , mapping_type : str , future : bool
1924+ ):
1925+ source = f"""
1926+ import dataclasses
1927+ import collections
1928+ import typing
1929+
1930+ @dataclasses.dataclass
1931+ class Foo[T]({ mapping_type } [str, T]):
1932+ data: dict[str, T]
1933+
1934+ def __getitem__(self, x):
1935+ return self.data[x]
1936+
1937+ def __len__(self):
1938+ return len(self.data)
1939+
1940+ def __iter__(self):
1941+ return iter(self.data)
1942+ """
1943+
1944+ if future :
1945+ source = "from __future__ import annotations\n " + source
1946+
1947+ with temp_module (source ) as mod , pytest .raises (ValidationError ):
1948+ msgspec .msgpack .decode (
1949+ msgspec .msgpack .encode (mod .Foo ({"x" : "foo" })), type = mod .Foo [int ]
1950+ )
1951+
1952+ msgspec .msgpack .decode (
1953+ msgspec .msgpack .encode (mod .Foo ({"x" : 1 })), type = mod .Foo [int ]
1954+ )
1955+
17221956
17231957class TestStructOmitDefaults :
17241958 def test_omit_defaults (self , proto ):
0 commit comments