11from __future__ import annotations
22
3- from typing import Callable , Sequence
3+ from typing import Callable , Iterable , Sequence
44
55import mypy .subtypes
66from mypy .erasetype import erase_typevars
77from mypy .expandtype import expand_type
8- from mypy .nodes import Context
8+ from mypy .nodes import Context , TypeInfo
9+ from mypy .type_visitor import TypeTranslator
10+ from mypy .typeops import get_all_type_vars
911from mypy .types import (
1012 AnyType ,
1113 CallableType ,
14+ Instance ,
15+ Parameters ,
16+ ParamSpecFlavor ,
1217 ParamSpecType ,
1318 PartialType ,
19+ ProperType ,
1420 Type ,
21+ TypeAliasType ,
1522 TypeVarId ,
1623 TypeVarLikeType ,
1724 TypeVarTupleType ,
1825 TypeVarType ,
1926 UninhabitedType ,
2027 UnpackType ,
2128 get_proper_type ,
29+ remove_dups ,
2230)
2331
2432
@@ -93,8 +101,7 @@ def apply_generic_arguments(
93101 bound or constraints, instead of giving an error.
94102 """
95103 tvars = callable .variables
96- min_arg_count = sum (not tv .has_default () for tv in tvars )
97- assert min_arg_count <= len (orig_types ) <= len (tvars )
104+ assert len (orig_types ) <= len (tvars )
98105 # Check that inferred type variable values are compatible with allowed
99106 # values and bounds. Also, promote subtype values to allowed values.
100107 # Create a map from type variable id to target type.
@@ -148,7 +155,7 @@ def apply_generic_arguments(
148155 type_is = None
149156
150157 # The callable may retain some type vars if only some were applied.
151- # TODO: move apply_poly() logic from checkexpr.py here when new inference
158+ # TODO: move apply_poly() logic here when new inference
152159 # becomes universally used (i.e. in all passes + in unification).
153160 # With this new logic we can actually *add* some new free variables.
154161 remaining_tvars : list [TypeVarLikeType ] = []
@@ -170,3 +177,126 @@ def apply_generic_arguments(
170177 type_guard = type_guard ,
171178 type_is = type_is ,
172179 )
180+
181+
182+ def apply_poly (tp : CallableType , poly_tvars : Sequence [TypeVarLikeType ]) -> CallableType | None :
183+ """Make free type variables generic in the type if possible.
184+
185+ This will translate the type `tp` while trying to create valid bindings for
186+ type variables `poly_tvars` while traversing the type. This follows the same rules
187+ as we do during semantic analysis phase, examples:
188+ * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
189+ * Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
190+ * List[T] -> None (not possible)
191+ """
192+ try :
193+ return tp .copy_modified (
194+ arg_types = [t .accept (PolyTranslator (poly_tvars )) for t in tp .arg_types ],
195+ ret_type = tp .ret_type .accept (PolyTranslator (poly_tvars )),
196+ variables = [],
197+ )
198+ except PolyTranslationError :
199+ return None
200+
201+
202+ class PolyTranslationError (Exception ):
203+ pass
204+
205+
206+ class PolyTranslator (TypeTranslator ):
207+ """Make free type variables generic in the type if possible.
208+
209+ See docstring for apply_poly() for details.
210+ """
211+
212+ def __init__ (
213+ self ,
214+ poly_tvars : Iterable [TypeVarLikeType ],
215+ bound_tvars : frozenset [TypeVarLikeType ] = frozenset (),
216+ seen_aliases : frozenset [TypeInfo ] = frozenset (),
217+ ) -> None :
218+ self .poly_tvars = set (poly_tvars )
219+ # This is a simplified version of TypeVarScope used during semantic analysis.
220+ self .bound_tvars = bound_tvars
221+ self .seen_aliases = seen_aliases
222+
223+ def collect_vars (self , t : CallableType | Parameters ) -> list [TypeVarLikeType ]:
224+ found_vars = []
225+ for arg in t .arg_types :
226+ for tv in get_all_type_vars (arg ):
227+ if isinstance (tv , ParamSpecType ):
228+ normalized : TypeVarLikeType = tv .copy_modified (
229+ flavor = ParamSpecFlavor .BARE , prefix = Parameters ([], [], [])
230+ )
231+ else :
232+ normalized = tv
233+ if normalized in self .poly_tvars and normalized not in self .bound_tvars :
234+ found_vars .append (normalized )
235+ return remove_dups (found_vars )
236+
237+ def visit_callable_type (self , t : CallableType ) -> Type :
238+ found_vars = self .collect_vars (t )
239+ self .bound_tvars |= set (found_vars )
240+ result = super ().visit_callable_type (t )
241+ self .bound_tvars -= set (found_vars )
242+
243+ assert isinstance (result , ProperType ) and isinstance (result , CallableType )
244+ result .variables = list (result .variables ) + found_vars
245+ return result
246+
247+ def visit_type_var (self , t : TypeVarType ) -> Type :
248+ if t in self .poly_tvars and t not in self .bound_tvars :
249+ raise PolyTranslationError ()
250+ return super ().visit_type_var (t )
251+
252+ def visit_param_spec (self , t : ParamSpecType ) -> Type :
253+ if t in self .poly_tvars and t not in self .bound_tvars :
254+ raise PolyTranslationError ()
255+ return super ().visit_param_spec (t )
256+
257+ def visit_type_var_tuple (self , t : TypeVarTupleType ) -> Type :
258+ if t in self .poly_tvars and t not in self .bound_tvars :
259+ raise PolyTranslationError ()
260+ return super ().visit_type_var_tuple (t )
261+
262+ def visit_type_alias_type (self , t : TypeAliasType ) -> Type :
263+ if not t .args :
264+ return t .copy_modified ()
265+ if not t .is_recursive :
266+ return get_proper_type (t ).accept (self )
267+ # We can't handle polymorphic application for recursive generic aliases
268+ # without risking an infinite recursion, just give up for now.
269+ raise PolyTranslationError ()
270+
271+ def visit_instance (self , t : Instance ) -> Type :
272+ if t .type .has_param_spec_type :
273+ # We need this special-casing to preserve the possibility to store a
274+ # generic function in an instance type. Things like
275+ # forall T . Foo[[x: T], T]
276+ # are not really expressible in current type system, but this looks like
277+ # a useful feature, so let's keep it.
278+ param_spec_index = next (
279+ i for (i , tv ) in enumerate (t .type .defn .type_vars ) if isinstance (tv , ParamSpecType )
280+ )
281+ p = get_proper_type (t .args [param_spec_index ])
282+ if isinstance (p , Parameters ):
283+ found_vars = self .collect_vars (p )
284+ self .bound_tvars |= set (found_vars )
285+ new_args = [a .accept (self ) for a in t .args ]
286+ self .bound_tvars -= set (found_vars )
287+
288+ repl = new_args [param_spec_index ]
289+ assert isinstance (repl , ProperType ) and isinstance (repl , Parameters )
290+ repl .variables = list (repl .variables ) + list (found_vars )
291+ return t .copy_modified (args = new_args )
292+ # There is the same problem with callback protocols as with aliases
293+ # (callback protocols are essentially more flexible aliases to callables).
294+ if t .args and t .type .is_protocol and t .type .protocol_members == ["__call__" ]:
295+ if t .type in self .seen_aliases :
296+ raise PolyTranslationError ()
297+ call = mypy .subtypes .find_member ("__call__" , t , t , is_operator = True )
298+ assert call is not None
299+ return call .accept (
300+ PolyTranslator (self .poly_tvars , self .bound_tvars , self .seen_aliases | {t .type })
301+ )
302+ return super ().visit_instance (t )
0 commit comments