|
18 | 18 | from mypy.checker_shared import ExpressionCheckerSharedApi |
19 | 19 | from mypy.checkmember import analyze_member_access, has_operator |
20 | 20 | from mypy.checkstrformat import StringFormatterChecker |
| 21 | +from mypy.constraints import SUBTYPE_OF, Constraint, infer_constraints |
21 | 22 | from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars |
22 | 23 | from mypy.errors import ErrorWatcher, report_internal_error |
23 | 24 | from mypy.expandtype import ( |
|
26 | 27 | freshen_all_functions_type_vars, |
27 | 28 | freshen_function_type_vars, |
28 | 29 | ) |
29 | | -from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments |
| 30 | +from mypy.infer import ( |
| 31 | + ArgumentInferContext, |
| 32 | + infer_constraints_for_callable, |
| 33 | + infer_function_type_arguments, |
| 34 | + infer_type_arguments, |
| 35 | + solve_constraints, |
| 36 | +) |
30 | 37 | from mypy.literals import literal |
31 | 38 | from mypy.maptype import map_instance_to_supertype |
32 | 39 | from mypy.meet import is_overlapping_types, narrow_declared_type |
@@ -1774,18 +1781,18 @@ def check_callable_call( |
1774 | 1781 | isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables |
1775 | 1782 | ) |
1776 | 1783 | callee = freshen_function_type_vars(callee) |
1777 | | - callee = self.infer_function_type_arguments_using_context(callee, context) |
1778 | | - if need_refresh: |
1779 | | - # Argument kinds etc. may have changed due to |
1780 | | - # ParamSpec or TypeVarTuple variables being replaced with an arbitrary |
1781 | | - # number of arguments; recalculate actual-to-formal map |
1782 | | - formal_to_actual = map_actuals_to_formals( |
1783 | | - arg_kinds, |
1784 | | - arg_names, |
1785 | | - callee.arg_kinds, |
1786 | | - callee.arg_names, |
1787 | | - lambda i: self.accept(args[i]), |
1788 | | - ) |
| 1784 | + # callee = self.infer_function_type_arguments_using_context(callee, context) |
| 1785 | + # if need_refresh: |
| 1786 | + # # Argument kinds etc. may have changed due to |
| 1787 | + # # ParamSpec or TypeVarTuple variables being replaced with an arbitrary |
| 1788 | + # # number of arguments; recalculate actual-to-formal map |
| 1789 | + # formal_to_actual = map_actuals_to_formals( |
| 1790 | + # arg_kinds, |
| 1791 | + # arg_names, |
| 1792 | + # callee.arg_kinds, |
| 1793 | + # callee.arg_names, |
| 1794 | + # lambda i: self.accept(args[i]), |
| 1795 | + # ) |
1789 | 1796 | callee = self.infer_function_type_arguments( |
1790 | 1797 | callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context |
1791 | 1798 | ) |
@@ -2085,6 +2092,88 @@ def infer_function_type_arguments_using_context( |
2085 | 2092 | callable, new_args, error_context, skip_unsatisfied=True |
2086 | 2093 | ) |
2087 | 2094 |
|
| 2095 | + def _infer_constraints_from_context( |
| 2096 | + self, callee: CallableType, error_context: Context |
| 2097 | + ) -> list[Constraint]: |
| 2098 | + """Unify callable return type to type context to infer type vars. |
| 2099 | +
|
| 2100 | + For example, if the return type is set[t] where 't' is a type variable |
| 2101 | + of callable, and if the context is set[int], return callable modified |
| 2102 | + by substituting 't' with 'int'. |
| 2103 | + """ |
| 2104 | + ctx = self.type_context[-1] |
| 2105 | + if not ctx: |
| 2106 | + return [] |
| 2107 | + # The return type may have references to type metavariables that |
| 2108 | + # we are inferring right now. We must consider them as indeterminate |
| 2109 | + # and they are not potential results; thus we replace them with the |
| 2110 | + # special ErasedType type. On the other hand, class type variables are |
| 2111 | + # valid results. |
| 2112 | + erased_ctx = replace_meta_vars(ctx, ErasedType()) |
| 2113 | + ret_type = callee.ret_type |
| 2114 | + if is_overlapping_none(ret_type) and is_overlapping_none(ctx): |
| 2115 | + # If both the context and the return type are optional, unwrap the optional, |
| 2116 | + # since in 99% cases this is what a user expects. In other words, we replace |
| 2117 | + # Optional[T] <: Optional[int] |
| 2118 | + # with |
| 2119 | + # T <: int |
| 2120 | + # while the former would infer T <: Optional[int]. |
| 2121 | + ret_type = remove_optional(ret_type) |
| 2122 | + erased_ctx = remove_optional(erased_ctx) |
| 2123 | + # |
| 2124 | + # TODO: Instead of this hack and the one below, we need to use outer and |
| 2125 | + # inner contexts at the same time. This is however not easy because of two |
| 2126 | + # reasons: |
| 2127 | + # * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables |
| 2128 | + # on both sides. (This is not too hard.) |
| 2129 | + # * We need to update all the inference "infrastructure", so that all |
| 2130 | + # variables in an expression are inferred at the same time. |
| 2131 | + # (And this is hard, also we need to be careful with lambdas that require |
| 2132 | + # two passes.) |
| 2133 | + proper_ret = get_proper_type(ret_type) |
| 2134 | + if ( |
| 2135 | + isinstance(proper_ret, TypeVarType) |
| 2136 | + or isinstance(proper_ret, UnionType) |
| 2137 | + and all(isinstance(get_proper_type(u), TypeVarType) for u in proper_ret.items) |
| 2138 | + ): |
| 2139 | + # Another special case: the return type is a type variable. If it's unrestricted, |
| 2140 | + # we could infer a too general type for the type variable if we use context, |
| 2141 | + # and this could result in confusing and spurious type errors elsewhere. |
| 2142 | + # |
| 2143 | + # So we give up and just use function arguments for type inference, with just two |
| 2144 | + # exceptions: |
| 2145 | + # |
| 2146 | + # 1. If the context is a generic instance type, actually use it as context, as |
| 2147 | + # this *seems* to usually be the reasonable thing to do. |
| 2148 | + # |
| 2149 | + # See also github issues #462 and #360. |
| 2150 | + # |
| 2151 | + # 2. If the context is some literal type, we want to "propagate" that information |
| 2152 | + # down so that we infer a more precise type for literal expressions. For example, |
| 2153 | + # the expression `3` normally has an inferred type of `builtins.int`: but if it's |
| 2154 | + # in a literal context like below, we want it to infer `Literal[3]` instead. |
| 2155 | + # |
| 2156 | + # def expects_literal(x: Literal[3]) -> None: pass |
| 2157 | + # def identity(x: T) -> T: return x |
| 2158 | + # |
| 2159 | + # expects_literal(identity(3)) # Should type-check |
| 2160 | + # TODO: we may want to add similar exception if all arguments are lambdas, since |
| 2161 | + # in this case external context is almost everything we have. |
| 2162 | + if not is_generic_instance(ctx) and not is_literal_type_like(ctx): |
| 2163 | + return [] |
| 2164 | + |
| 2165 | + constraints = infer_constraints(ret_type, erased_ctx, SUBTYPE_OF) |
| 2166 | + return constraints |
| 2167 | + |
| 2168 | + def _filter_args(self, args: list[Type | None]) -> list[Type | None]: |
| 2169 | + new_args: list[Type | None] = [] |
| 2170 | + for arg in args: |
| 2171 | + if has_uninhabited_component(arg) or has_erased_component(arg): |
| 2172 | + new_args.append(None) |
| 2173 | + else: |
| 2174 | + new_args.append(arg) |
| 2175 | + return new_args |
| 2176 | + |
2088 | 2177 | def infer_function_type_arguments( |
2089 | 2178 | self, |
2090 | 2179 | callee_type: CallableType, |
@@ -2131,6 +2220,133 @@ def infer_function_type_arguments( |
2131 | 2220 | context=self.argument_infer_context(), |
2132 | 2221 | strict=self.chk.in_checked_function(), |
2133 | 2222 | ) |
| 2223 | + old_inferred_args = inferred_args |
| 2224 | + new_inferred_args = None |
| 2225 | + |
| 2226 | + if True: # NEW CODE |
| 2227 | + extra_constraints = self._infer_constraints_from_context(callee_type, context) |
| 2228 | + |
| 2229 | + # outer_solution |
| 2230 | + _outer_solution = solve_constraints( |
| 2231 | + callee_type.variables, |
| 2232 | + extra_constraints, |
| 2233 | + strict=self.chk.in_checked_function(), |
| 2234 | + allow_polymorphic=False, |
| 2235 | + ) |
| 2236 | + outer_solution = (self._filter_args(_outer_solution[0]), _outer_solution[1]) |
| 2237 | + |
| 2238 | + # inner solution |
| 2239 | + constraints = infer_constraints_for_callable( |
| 2240 | + callee_type, |
| 2241 | + pass1_args, |
| 2242 | + arg_kinds, |
| 2243 | + arg_names, |
| 2244 | + formal_to_actual, |
| 2245 | + context=self.argument_infer_context(), |
| 2246 | + ) |
| 2247 | + inner_solution = solve_constraints( |
| 2248 | + callee_type.variables, |
| 2249 | + constraints, |
| 2250 | + strict=self.chk.in_checked_function(), |
| 2251 | + allow_polymorphic=False, |
| 2252 | + ) |
| 2253 | + |
| 2254 | + joint_solution = solve_constraints( |
| 2255 | + callee_type.variables, |
| 2256 | + extra_constraints + constraints, |
| 2257 | + strict=self.chk.in_checked_function(), |
| 2258 | + allow_polymorphic=False, |
| 2259 | + ) |
| 2260 | + |
| 2261 | + # check if we can use the joint solution, otherwise fallback to outer_solution |
| 2262 | + for var1, var2 in zip( |
| 2263 | + outer_solution[0], joint_solution[0] |
| 2264 | + ): # tuple[Type | None, Type | None] |
| 2265 | + if var2 is None and var1 is not None: |
| 2266 | + # using both constraints did not find a solution for this variable |
| 2267 | + # so we fallback to outer_solution, apply the solution, and then recompute the inner part |
| 2268 | + use_joint = False |
| 2269 | + break |
| 2270 | + else: |
| 2271 | + use_joint = True |
| 2272 | + |
| 2273 | + use_joint = True |
| 2274 | + use_outer = True |
| 2275 | + use_inner = True |
| 2276 | + |
| 2277 | + for outer_tp, inner_tp, joint_tp in zip( |
| 2278 | + outer_solution[0], inner_solution[0], joint_solution[0] |
| 2279 | + ): |
| 2280 | + if joint_tp is None and outer_tp is not None: |
| 2281 | + use_joint = False |
| 2282 | + if has_erased_component(joint_tp) and not has_erased_component(inner_tp): |
| 2283 | + # If the joint solution is erased, but outer is not, we use outer. |
| 2284 | + use_joint = False |
| 2285 | + if has_erased_component(outer_tp) and not has_erased_component(inner_tp): |
| 2286 | + use_outer = False |
| 2287 | + if has_erased_component(inner_tp): |
| 2288 | + use_inner = False |
| 2289 | + |
| 2290 | + if use_joint: |
| 2291 | + new_inferred_args = joint_solution[0] |
| 2292 | + # inferred_args = [ |
| 2293 | + # # Usually, joint_tp <: outer_tp (since superset of constraints), |
| 2294 | + # # fixes some cases where we would get `Literal[4]?` rather than `Literal[4]` |
| 2295 | + # (outer_tp if is_subtype(outer_tp, joint_tp) else joint_tp) |
| 2296 | + # for outer_tp, joint_tp in zip(outer_solution[0], joint_solution[0]) |
| 2297 | + # ] |
| 2298 | + elif use_outer: |
| 2299 | + # If we cannot use the joint solution, fallback to outer_solution |
| 2300 | + new_inferred_args = outer_solution[0] |
| 2301 | + |
| 2302 | + # Only substitute non-Uninhabited and non-erased types. |
| 2303 | + new_args: list[Type | None] = [] |
| 2304 | + for arg in new_inferred_args: |
| 2305 | + if has_uninhabited_component(arg) or has_erased_component(arg): |
| 2306 | + new_args.append(None) |
| 2307 | + else: |
| 2308 | + new_args.append(arg) |
| 2309 | + # Don't show errors after we have only used the outer context for inference. |
| 2310 | + # We will use argument context to infer more variables. |
| 2311 | + callee_type = self.apply_generic_arguments( |
| 2312 | + callee_type, new_args, context, skip_unsatisfied=True |
| 2313 | + ) |
| 2314 | + if need_refresh: |
| 2315 | + # Argument kinds etc. may have changed due to |
| 2316 | + # ParamSpec or TypeVarTuple variables being replaced with an arbitrary |
| 2317 | + # number of arguments; recalculate actual-to-formal map |
| 2318 | + formal_to_actual = map_actuals_to_formals( |
| 2319 | + arg_kinds, |
| 2320 | + arg_names, |
| 2321 | + callee_type.arg_kinds, |
| 2322 | + callee_type.arg_names, |
| 2323 | + lambda i: self.accept(args[i]), |
| 2324 | + ) |
| 2325 | + new_inferred_args, _ = infer_function_type_arguments( |
| 2326 | + callee_type, |
| 2327 | + pass1_args, |
| 2328 | + arg_kinds, |
| 2329 | + arg_names, |
| 2330 | + formal_to_actual, |
| 2331 | + context=self.argument_infer_context(), |
| 2332 | + strict=self.chk.in_checked_function(), |
| 2333 | + ) |
| 2334 | + elif use_inner: |
| 2335 | + new_inferred_args = inner_solution[0] |
| 2336 | + else: |
| 2337 | + raise RuntimeError("No solution found for function type arguments") |
| 2338 | + else: # OLD CODE |
| 2339 | + pass |
| 2340 | + |
| 2341 | + if True: # USE NEW CODE |
| 2342 | + inferred_args = new_inferred_args |
| 2343 | + else: # USE OLD CODE |
| 2344 | + inferred_args = old_inferred_args |
| 2345 | + |
| 2346 | + # show me |
| 2347 | + _1 = new_inferred_args |
| 2348 | + _2 = old_inferred_args |
| 2349 | + _3 = inferred_args |
2134 | 2350 |
|
2135 | 2351 | if 2 in arg_pass_nums: |
2136 | 2352 | # Second pass of type inference. |
|
0 commit comments