1- from types import GenericAlias
2- from typing import List , Type , TypeVar , Dict , Generic , Any , Optional
1+ from typing import Type , TypeVar , Dict , get_origin , get_args
32
43
54class GenericMixin :
@@ -11,7 +10,7 @@ class GenericMixin:
1110 >>> T = TypeVar('T')
1211 >>> U = TypeVar('U')
1312 >>> class Foo(Generic[T, U], GenericMixin):
14- ... values: List [T]
13+ ... values: list [T]
1514 ... value: U
1615 >>> f = Foo[str, int]()
1716 >>> f.type_vars
@@ -24,6 +23,8 @@ def type_var(self) -> Type:
2423 Get the type variable for this class.
2524 Use this for convenience if your class has only one type parameter.
2625
26+ DO NOT call this inside __init__()!
27+
2728 Example:
2829 >>> from typing import Generic, TypeVar
2930 >>> T = TypeVar('T')
@@ -34,7 +35,7 @@ def type_var(self) -> Type:
3435 <class 'float'>
3536 """
3637
37- types = self ._get_types ()
38+ types = self ._get_resolved_typevars ()
3839 assert len (types ) == 1 , f'You have multiple type parameters. Please use "type_vars" instead of "type_var".'
3940 return list (types .values ())[0 ] # type: ignore
4041
@@ -43,55 +44,77 @@ def type_vars(self) -> Dict[TypeVar, Type]:
4344 """
4445 Returns the mapping of type variables to types.
4546
47+ DO NOT call this inside __init__()!
48+
4649 Example:
4750 >>> from typing import Generic, TypeVar
4851 >>> T = TypeVar('T')
4952 >>> U = TypeVar('U')
5053 >>> class Foo(Generic[T, U], GenericMixin):
51- ... values: List [T]
54+ ... values: list [T]
5255 ... value: U
5356 >>> f = Foo[str, int]()
5457 >>> f.type_vars
5558 {~T: <class 'str'>, ~U: <class 'int'>}
5659 """
5760
58- return self ._get_types ()
61+ return self ._get_resolved_typevars ()
5962
60- def _get_types (self ) -> Dict [TypeVar , Type ]:
63+ def _get_resolved_typevars (self ) -> Dict [TypeVar , Type ]:
6164 """
62- See https://stackoverflow.com/questions/57706180/generict-base-class-how-to-get-type-of-t-from-within-instance/60984681#60984681
65+ Do not call this inside the __init__() method, because at that point the relevant information are not present.
66+ See also https://github.com/python/cpython/issues/90899'
6367 """
6468
69+ mapping : dict [TypeVar , type ] = {}
70+
6571 non_generic_error = AssertionError (
6672 f'{ self .class_name } is not a generic class. To make it generic, declare it like: '
6773 f'class { self .class_name } (Generic[T], GenericMixin):...' )
6874
6975 if not hasattr (self , '__orig_bases__' ):
7076 raise non_generic_error
7177
72- generic_base = get_generic_base (obj = self )
78+ def collect (base , substitutions : dict [TypeVar , type ]) -> None :
79+ origin = get_origin (base )
80+ args = get_args (base )
81+
82+ if origin is None :
83+ return
7384
74- if not generic_base :
75- for base in self .__orig_bases__ : # type: ignore # (we checked existence above)
76- if not hasattr (base , '__origin__' ):
77- continue
85+ params = getattr (origin , '__parameters__' , ())
86+ resolved = {}
87+ for param , arg in zip (params , args ):
88+ resolved_arg = substitutions .get (arg , arg ) if isinstance (arg , TypeVar ) else arg
89+ mapping [param ] = resolved_arg
90+ resolved [param ] = resolved_arg
7891
79- generic_base = get_generic_base (base .__origin__ )
92+ for super_base in getattr (origin , '__orig_bases__' , []):
93+ collect (super_base , resolved )
8094
81- if generic_base :
82- types = base .__args__
83- break
95+ # Prefer __orig_class__ if available
96+ cls = getattr (self , '__orig_class__' , None )
97+ if cls is not None :
98+ collect (base = cls , substitutions = {})
8499 else :
85- if not hasattr (self , '__orig_class__' ):
86- raise AssertionError (
87- f'You need to instantiate this class with type parameters! Example: { self .class_name } [int]()\n '
88- f'Also make sure that you do not call this in the __init__() method of your class! '
89- f'See also https://github.com/python/cpython/issues/90899' )
100+ for base in getattr (self .__class__ , '__orig_bases__' , []):
101+ collect (base = base , substitutions = {})
90102
91- types = self .__orig_class__ .__args__ # type: ignore
103+ # Extra safety: ensure all declared typevars are resolved
104+ all_params = set ()
105+ for cls in self .__class__ .__mro__ :
106+ all_params .update (getattr (cls , '__parameters__' , ()))
92107
93- type_vars = generic_base .__args__
94- return {v : t for v , t in zip (type_vars , types )}
108+ unresolved = {param for param in all_params if param not in mapping or isinstance (mapping [param ], TypeVar )}
109+ if unresolved :
110+ raise AssertionError (
111+ f'You need to instantiate this class with type parameters! Example: { self .class_name } [int]()\n '
112+ f'Also make sure that you do not call this in the __init__() method of your class!\n '
113+ f'Unresolved type variables: { unresolved } \n '
114+ f'See also https://github.com/python/cpython/issues/90899'
115+ )
116+
117+ return mapping
95118
96119 @property
97120 def class_name (self ) -> str :
@@ -100,14 +123,6 @@ def class_name(self) -> str:
100123 return type (self ).__name__
101124
102125
103- def get_generic_base (obj : Any ) -> Optional [GenericAlias ]:
104- generic_bases = [c for c in obj .__orig_bases__ if hasattr (c , '__origin__' ) and c .__origin__ == Generic ]
105-
106- if generic_bases :
107- return generic_bases [0 ] # this is safe because a class can have at most one "Generic" superclass
108-
109- return None
110-
111126if __name__ == '__main__' :
112127 import doctest
113128 doctest .testmod (verbose = False , optionflags = doctest .ELLIPSIS )
0 commit comments