Skip to content

Commit 37b85b9

Browse files
improve GenericMixin such that it also find bound type variables in parent classes
1 parent 32e7d27 commit 37b85b9

File tree

4 files changed

+68
-34
lines changed

4 files changed

+68
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# Changelog
2+
## Pedantic 2.1.11
3+
- improve `GenericMixin` such that it also find bound type variables in parent classes
4+
25
## Pedantic 2.1.10
36
- added type check support for `functools.partial`
47
- update dependencies

pedantic/mixins/generic_mixin.py

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from types import GenericAlias
2-
from typing import List, Type, TypeVar, Dict, Generic, Any, Optional
1+
from typing import Type, TypeVar, Dict, get_origin, get_args
32

43

54
class GenericMixin:
@@ -11,7 +10,7 @@ class GenericMixin:
1110
>>> T = TypeVar('T')
1211
>>> U = TypeVar('U')
1312
>>> class Foo(Generic[T, U], GenericMixin):
14-
... values: List[T]
13+
... values: list[T]
1514
... value: U
1615
>>> f = Foo[str, int]()
1716
>>> f.type_vars
@@ -24,6 +23,8 @@ def type_var(self) -> Type:
2423
Get the type variable for this class.
2524
Use this for convenience if your class has only one type parameter.
2625
26+
DO NOT call this inside __init__()!
27+
2728
Example:
2829
>>> from typing import Generic, TypeVar
2930
>>> T = TypeVar('T')
@@ -34,7 +35,7 @@ def type_var(self) -> Type:
3435
<class 'float'>
3536
"""
3637

37-
types = self._get_types()
38+
types = self._get_resolved_typevars()
3839
assert len(types) == 1, f'You have multiple type parameters. Please use "type_vars" instead of "type_var".'
3940
return list(types.values())[0] # type: ignore
4041

@@ -43,55 +44,77 @@ def type_vars(self) -> Dict[TypeVar, Type]:
4344
"""
4445
Returns the mapping of type variables to types.
4546
47+
DO NOT call this inside __init__()!
48+
4649
Example:
4750
>>> from typing import Generic, TypeVar
4851
>>> T = TypeVar('T')
4952
>>> U = TypeVar('U')
5053
>>> class Foo(Generic[T, U], GenericMixin):
51-
... values: List[T]
54+
... values: list[T]
5255
... value: U
5356
>>> f = Foo[str, int]()
5457
>>> f.type_vars
5558
{~T: <class 'str'>, ~U: <class 'int'>}
5659
"""
5760

58-
return self._get_types()
61+
return self._get_resolved_typevars()
5962

60-
def _get_types(self) -> Dict[TypeVar, Type]:
63+
def _get_resolved_typevars(self) -> Dict[TypeVar, Type]:
6164
"""
62-
See https://stackoverflow.com/questions/57706180/generict-base-class-how-to-get-type-of-t-from-within-instance/60984681#60984681
65+
Do not call this inside the __init__() method, because at that point the relevant information are not present.
66+
See also https://github.com/python/cpython/issues/90899'
6367
"""
6468

69+
mapping: dict[TypeVar, type] = {}
70+
6571
non_generic_error = AssertionError(
6672
f'{self.class_name} is not a generic class. To make it generic, declare it like: '
6773
f'class {self.class_name}(Generic[T], GenericMixin):...')
6874

6975
if not hasattr(self, '__orig_bases__'):
7076
raise non_generic_error
7177

72-
generic_base = get_generic_base(obj=self)
78+
def collect(base, substitutions: dict[TypeVar, type]) -> None:
79+
origin = get_origin(base)
80+
args = get_args(base)
81+
82+
if origin is None:
83+
return
7384

74-
if not generic_base:
75-
for base in self.__orig_bases__: # type: ignore # (we checked existence above)
76-
if not hasattr(base, '__origin__'):
77-
continue
85+
params = getattr(origin, '__parameters__', ())
86+
resolved = {}
87+
for param, arg in zip(params, args):
88+
resolved_arg = substitutions.get(arg, arg) if isinstance(arg, TypeVar) else arg
89+
mapping[param] = resolved_arg
90+
resolved[param] = resolved_arg
7891

79-
generic_base = get_generic_base(base.__origin__)
92+
for super_base in getattr(origin, '__orig_bases__', []):
93+
collect(super_base, resolved)
8094

81-
if generic_base:
82-
types = base.__args__
83-
break
95+
# Prefer __orig_class__ if available
96+
cls = getattr(self, '__orig_class__', None)
97+
if cls is not None:
98+
collect(base=cls, substitutions={})
8499
else:
85-
if not hasattr(self, '__orig_class__'):
86-
raise AssertionError(
87-
f'You need to instantiate this class with type parameters! Example: {self.class_name}[int]()\n'
88-
f'Also make sure that you do not call this in the __init__() method of your class! '
89-
f'See also https://github.com/python/cpython/issues/90899')
100+
for base in getattr(self.__class__, '__orig_bases__', []):
101+
collect(base=base, substitutions={})
90102

91-
types = self.__orig_class__.__args__ # type: ignore
103+
# Extra safety: ensure all declared typevars are resolved
104+
all_params = set()
105+
for cls in self.__class__.__mro__:
106+
all_params.update(getattr(cls, '__parameters__', ()))
92107

93-
type_vars = generic_base.__args__
94-
return {v: t for v, t in zip(type_vars, types)}
108+
unresolved = {param for param in all_params if param not in mapping or isinstance(mapping[param], TypeVar)}
109+
if unresolved:
110+
raise AssertionError(
111+
f'You need to instantiate this class with type parameters! Example: {self.class_name}[int]()\n'
112+
f'Also make sure that you do not call this in the __init__() method of your class!\n'
113+
f'Unresolved type variables: {unresolved}\n'
114+
f'See also https://github.com/python/cpython/issues/90899'
115+
)
116+
117+
return mapping
95118

96119
@property
97120
def class_name(self) -> str:
@@ -100,14 +123,6 @@ def class_name(self) -> str:
100123
return type(self).__name__
101124

102125

103-
def get_generic_base(obj: Any) -> Optional[GenericAlias]:
104-
generic_bases = [c for c in obj.__orig_bases__ if hasattr(c, '__origin__') and c.__origin__ == Generic]
105-
106-
if generic_bases:
107-
return generic_bases[0] # this is safe because a class can have at most one "Generic" superclass
108-
109-
return None
110-
111126
if __name__ == '__main__':
112127
import doctest
113128
doctest.testmod(verbose=False, optionflags=doctest.ELLIPSIS)

pedantic/tests/test_generic_mixin.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,19 @@ class MyClass(MyMixin, Gen[int]):
103103

104104
foo = MyClass(value=4)
105105
assert foo.get_type() == {T: int}
106+
107+
def test_resolved_type_var_inheritance(self):
108+
class Foo(Generic[T]): ...
109+
110+
class Bar(Foo[int], Generic[U], GenericMixin): ...
111+
112+
bar = Bar[str]()
113+
assert bar.type_vars == {T: int, U: str}
114+
115+
def test_resolved_type_var_inheritance_2(self):
116+
class Foo(Generic[T], GenericMixin): ...
117+
118+
class Bar(Foo[int], Generic[U]): ...
119+
120+
bar = Bar[str]()
121+
assert bar.type_vars == {T: int, U: str}

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def get_content_from_readme(file_name: str = 'README.md') -> str:
1515

1616
setup(
1717
name="pedantic",
18-
version="2.1.10",
18+
version="2.1.11",
1919
python_requires='>=3.11.0',
2020
packages=find_packages(),
2121
install_requires=[],

0 commit comments

Comments
 (0)