Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Changelog
## Pedantic 2.1.11
- improve `GenericMixin` such that it also find bound type variables in parent classes

## Pedantic 2.1.10
- added type check support for `functools.partial`
- update dependencies
Expand Down
81 changes: 48 additions & 33 deletions pedantic/mixins/generic_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from types import GenericAlias
from typing import List, Type, TypeVar, Dict, Generic, Any, Optional
from typing import Type, TypeVar, Dict, get_origin, get_args


class GenericMixin:
Expand All @@ -11,7 +10,7 @@ class GenericMixin:
>>> T = TypeVar('T')
>>> U = TypeVar('U')
>>> class Foo(Generic[T, U], GenericMixin):
... values: List[T]
... values: list[T]
... value: U
>>> f = Foo[str, int]()
>>> f.type_vars
Expand All @@ -24,6 +23,8 @@ def type_var(self) -> Type:
Get the type variable for this class.
Use this for convenience if your class has only one type parameter.

DO NOT call this inside __init__()!

Example:
>>> from typing import Generic, TypeVar
>>> T = TypeVar('T')
Expand All @@ -34,7 +35,7 @@ def type_var(self) -> Type:
<class 'float'>
"""

types = self._get_types()
types = self._get_resolved_typevars()
assert len(types) == 1, f'You have multiple type parameters. Please use "type_vars" instead of "type_var".'
return list(types.values())[0] # type: ignore

Expand All @@ -43,55 +44,77 @@ def type_vars(self) -> Dict[TypeVar, Type]:
"""
Returns the mapping of type variables to types.

DO NOT call this inside __init__()!

Example:
>>> from typing import Generic, TypeVar
>>> T = TypeVar('T')
>>> U = TypeVar('U')
>>> class Foo(Generic[T, U], GenericMixin):
... values: List[T]
... values: list[T]
... value: U
>>> f = Foo[str, int]()
>>> f.type_vars
{~T: <class 'str'>, ~U: <class 'int'>}
"""

return self._get_types()
return self._get_resolved_typevars()

def _get_types(self) -> Dict[TypeVar, Type]:
def _get_resolved_typevars(self) -> Dict[TypeVar, Type]:
"""
See https://stackoverflow.com/questions/57706180/generict-base-class-how-to-get-type-of-t-from-within-instance/60984681#60984681
Do not call this inside the __init__() method, because at that point the relevant information are not present.
See also https://github.com/python/cpython/issues/90899'
"""

mapping: dict[TypeVar, type] = {}

non_generic_error = AssertionError(
f'{self.class_name} is not a generic class. To make it generic, declare it like: '
f'class {self.class_name}(Generic[T], GenericMixin):...')

if not hasattr(self, '__orig_bases__'):
raise non_generic_error

generic_base = get_generic_base(obj=self)
def collect(base, substitutions: dict[TypeVar, type]) -> None:
origin = get_origin(base)
args = get_args(base)

if origin is None:
return

if not generic_base:
for base in self.__orig_bases__: # type: ignore # (we checked existence above)
if not hasattr(base, '__origin__'):
continue
params = getattr(origin, '__parameters__', ())
resolved = {}
for param, arg in zip(params, args):
resolved_arg = substitutions.get(arg, arg) if isinstance(arg, TypeVar) else arg
mapping[param] = resolved_arg
resolved[param] = resolved_arg

generic_base = get_generic_base(base.__origin__)
for super_base in getattr(origin, '__orig_bases__', []):
collect(super_base, resolved)

if generic_base:
types = base.__args__
break
# Prefer __orig_class__ if available
cls = getattr(self, '__orig_class__', None)
if cls is not None:
collect(base=cls, substitutions={})
else:
if not hasattr(self, '__orig_class__'):
raise AssertionError(
f'You need to instantiate this class with type parameters! Example: {self.class_name}[int]()\n'
f'Also make sure that you do not call this in the __init__() method of your class! '
f'See also https://github.com/python/cpython/issues/90899')
for base in getattr(self.__class__, '__orig_bases__', []):
collect(base=base, substitutions={})

types = self.__orig_class__.__args__ # type: ignore
# Extra safety: ensure all declared typevars are resolved
all_params = set()
for cls in self.__class__.__mro__:
all_params.update(getattr(cls, '__parameters__', ()))

type_vars = generic_base.__args__
return {v: t for v, t in zip(type_vars, types)}
unresolved = {param for param in all_params if param not in mapping or isinstance(mapping[param], TypeVar)}
if unresolved:
raise AssertionError(
f'You need to instantiate this class with type parameters! Example: {self.class_name}[int]()\n'
f'Also make sure that you do not call this in the __init__() method of your class!\n'
f'Unresolved type variables: {unresolved}\n'
f'See also https://github.com/python/cpython/issues/90899'
)

return mapping

@property
def class_name(self) -> str:
Expand All @@ -100,14 +123,6 @@ def class_name(self) -> str:
return type(self).__name__


def get_generic_base(obj: Any) -> Optional[GenericAlias]:
generic_bases = [c for c in obj.__orig_bases__ if hasattr(c, '__origin__') and c.__origin__ == Generic]

if generic_bases:
return generic_bases[0] # this is safe because a class can have at most one "Generic" superclass

return None

if __name__ == '__main__':
import doctest
doctest.testmod(verbose=False, optionflags=doctest.ELLIPSIS)
16 changes: 16 additions & 0 deletions pedantic/tests/test_generic_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,19 @@ class MyClass(MyMixin, Gen[int]):

foo = MyClass(value=4)
assert foo.get_type() == {T: int}

def test_resolved_type_var_inheritance(self):
class Foo(Generic[T]): ...

class Bar(Foo[int], Generic[U], GenericMixin): ...

bar = Bar[str]()
assert bar.type_vars == {T: int, U: str}

def test_resolved_type_var_inheritance_2(self):
class Foo(Generic[T], GenericMixin): ...

class Bar(Foo[int], Generic[U]): ...

bar = Bar[str]()
assert bar.type_vars == {T: int, U: str}
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_content_from_readme(file_name: str = 'README.md') -> str:

setup(
name="pedantic",
version="2.1.10",
version="2.1.11",
python_requires='>=3.11.0',
packages=find_packages(),
install_requires=[],
Expand Down