diff --git a/CHANGELOG.md b/CHANGELOG.md index 0337367..c9b25c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,7 @@ # Changelog +## Pedantic 2.1.11 +- improve `GenericMixin` such that it also find bound type variables in parent classes + ## Pedantic 2.1.10 - added type check support for `functools.partial` - update dependencies diff --git a/pedantic/mixins/generic_mixin.py b/pedantic/mixins/generic_mixin.py index b83d452..bcb9d8c 100644 --- a/pedantic/mixins/generic_mixin.py +++ b/pedantic/mixins/generic_mixin.py @@ -1,5 +1,4 @@ -from types import GenericAlias -from typing import List, Type, TypeVar, Dict, Generic, Any, Optional +from typing import Type, TypeVar, Dict, get_origin, get_args class GenericMixin: @@ -11,7 +10,7 @@ class GenericMixin: >>> T = TypeVar('T') >>> U = TypeVar('U') >>> class Foo(Generic[T, U], GenericMixin): - ... values: List[T] + ... values: list[T] ... value: U >>> f = Foo[str, int]() >>> f.type_vars @@ -24,6 +23,8 @@ def type_var(self) -> Type: Get the type variable for this class. Use this for convenience if your class has only one type parameter. + DO NOT call this inside __init__()! + Example: >>> from typing import Generic, TypeVar >>> T = TypeVar('T') @@ -34,7 +35,7 @@ def type_var(self) -> Type: """ - types = self._get_types() + types = self._get_resolved_typevars() assert len(types) == 1, f'You have multiple type parameters. Please use "type_vars" instead of "type_var".' return list(types.values())[0] # type: ignore @@ -43,25 +44,30 @@ def type_vars(self) -> Dict[TypeVar, Type]: """ Returns the mapping of type variables to types. + DO NOT call this inside __init__()! + Example: >>> from typing import Generic, TypeVar >>> T = TypeVar('T') >>> U = TypeVar('U') >>> class Foo(Generic[T, U], GenericMixin): - ... values: List[T] + ... values: list[T] ... value: U >>> f = Foo[str, int]() >>> f.type_vars {~T: , ~U: } """ - return self._get_types() + return self._get_resolved_typevars() - def _get_types(self) -> Dict[TypeVar, Type]: + def _get_resolved_typevars(self) -> Dict[TypeVar, Type]: """ - See https://stackoverflow.com/questions/57706180/generict-base-class-how-to-get-type-of-t-from-within-instance/60984681#60984681 + Do not call this inside the __init__() method, because at that point the relevant information are not present. + See also https://github.com/python/cpython/issues/90899' """ + mapping: dict[TypeVar, type] = {} + non_generic_error = AssertionError( f'{self.class_name} is not a generic class. To make it generic, declare it like: ' f'class {self.class_name}(Generic[T], GenericMixin):...') @@ -69,29 +75,46 @@ def _get_types(self) -> Dict[TypeVar, Type]: if not hasattr(self, '__orig_bases__'): raise non_generic_error - generic_base = get_generic_base(obj=self) + def collect(base, substitutions: dict[TypeVar, type]) -> None: + origin = get_origin(base) + args = get_args(base) + + if origin is None: + return - if not generic_base: - for base in self.__orig_bases__: # type: ignore # (we checked existence above) - if not hasattr(base, '__origin__'): - continue + params = getattr(origin, '__parameters__', ()) + resolved = {} + for param, arg in zip(params, args): + resolved_arg = substitutions.get(arg, arg) if isinstance(arg, TypeVar) else arg + mapping[param] = resolved_arg + resolved[param] = resolved_arg - generic_base = get_generic_base(base.__origin__) + for super_base in getattr(origin, '__orig_bases__', []): + collect(super_base, resolved) - if generic_base: - types = base.__args__ - break + # Prefer __orig_class__ if available + cls = getattr(self, '__orig_class__', None) + if cls is not None: + collect(base=cls, substitutions={}) else: - if not hasattr(self, '__orig_class__'): - raise AssertionError( - f'You need to instantiate this class with type parameters! Example: {self.class_name}[int]()\n' - f'Also make sure that you do not call this in the __init__() method of your class! ' - f'See also https://github.com/python/cpython/issues/90899') + for base in getattr(self.__class__, '__orig_bases__', []): + collect(base=base, substitutions={}) - types = self.__orig_class__.__args__ # type: ignore + # Extra safety: ensure all declared typevars are resolved + all_params = set() + for cls in self.__class__.__mro__: + all_params.update(getattr(cls, '__parameters__', ())) - type_vars = generic_base.__args__ - return {v: t for v, t in zip(type_vars, types)} + unresolved = {param for param in all_params if param not in mapping or isinstance(mapping[param], TypeVar)} + if unresolved: + raise AssertionError( + f'You need to instantiate this class with type parameters! Example: {self.class_name}[int]()\n' + f'Also make sure that you do not call this in the __init__() method of your class!\n' + f'Unresolved type variables: {unresolved}\n' + f'See also https://github.com/python/cpython/issues/90899' + ) + + return mapping @property def class_name(self) -> str: @@ -100,14 +123,6 @@ def class_name(self) -> str: return type(self).__name__ -def get_generic_base(obj: Any) -> Optional[GenericAlias]: - generic_bases = [c for c in obj.__orig_bases__ if hasattr(c, '__origin__') and c.__origin__ == Generic] - - if generic_bases: - return generic_bases[0] # this is safe because a class can have at most one "Generic" superclass - - return None - if __name__ == '__main__': import doctest doctest.testmod(verbose=False, optionflags=doctest.ELLIPSIS) diff --git a/pedantic/tests/test_generic_mixin.py b/pedantic/tests/test_generic_mixin.py index fca1f0f..fac4634 100644 --- a/pedantic/tests/test_generic_mixin.py +++ b/pedantic/tests/test_generic_mixin.py @@ -103,3 +103,19 @@ class MyClass(MyMixin, Gen[int]): foo = MyClass(value=4) assert foo.get_type() == {T: int} + + def test_resolved_type_var_inheritance(self): + class Foo(Generic[T]): ... + + class Bar(Foo[int], Generic[U], GenericMixin): ... + + bar = Bar[str]() + assert bar.type_vars == {T: int, U: str} + + def test_resolved_type_var_inheritance_2(self): + class Foo(Generic[T], GenericMixin): ... + + class Bar(Foo[int], Generic[U]): ... + + bar = Bar[str]() + assert bar.type_vars == {T: int, U: str} diff --git a/setup.py b/setup.py index 7f1d732..e9079b3 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def get_content_from_readme(file_name: str = 'README.md') -> str: setup( name="pedantic", - version="2.1.10", + version="2.1.11", python_requires='>=3.11.0', packages=find_packages(), install_requires=[],