|
90 | 90 | ContinueStmt, |
91 | 91 | Decorator, |
92 | 92 | DelStmt, |
| 93 | + DictExpr, |
93 | 94 | EllipsisExpr, |
94 | 95 | Expression, |
95 | 96 | ExpressionStmt, |
|
124 | 125 | RaiseStmt, |
125 | 126 | RefExpr, |
126 | 127 | ReturnStmt, |
| 128 | + SetExpr, |
127 | 129 | StarExpr, |
128 | 130 | Statement, |
129 | 131 | StrExpr, |
@@ -4859,6 +4861,42 @@ def visit_return_stmt(self, s: ReturnStmt) -> None: |
4859 | 4861 | self.check_return_stmt(s) |
4860 | 4862 | self.binder.unreachable() |
4861 | 4863 |
|
| 4864 | + def infer_context_dependent( |
| 4865 | + self, expr: Expression, type_ctx: Type, allow_none_func_call: bool |
| 4866 | + ) -> ProperType: |
| 4867 | + """Infer type of an expression with fallback to empty type context.""" |
| 4868 | + with self.msg.filter_errors( |
| 4869 | + filter_errors=True, filter_deprecated=True, save_filtered_errors=True |
| 4870 | + ) as msg: |
| 4871 | + with self.local_type_map as type_map: |
| 4872 | + typ = get_proper_type( |
| 4873 | + self.expr_checker.accept( |
| 4874 | + expr, type_ctx, allow_none_return=allow_none_func_call |
| 4875 | + ) |
| 4876 | + ) |
| 4877 | + if not msg.has_new_errors(): |
| 4878 | + self.store_types(type_map) |
| 4879 | + return typ |
| 4880 | + |
| 4881 | + # If there are errors with the original type context, try re-inferring in empty context. |
| 4882 | + original_messages = msg.filtered_errors() |
| 4883 | + original_type_map = type_map |
| 4884 | + with self.msg.filter_errors( |
| 4885 | + filter_errors=True, filter_deprecated=True, save_filtered_errors=True |
| 4886 | + ) as msg: |
| 4887 | + with self.local_type_map as type_map: |
| 4888 | + alt_typ = get_proper_type( |
| 4889 | + self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call) |
| 4890 | + ) |
| 4891 | + if not msg.has_new_errors() and is_subtype(alt_typ, type_ctx): |
| 4892 | + self.store_types(type_map) |
| 4893 | + return alt_typ |
| 4894 | + |
| 4895 | + # If empty fallback didn't work, use results from the original type context. |
| 4896 | + self.msg.add_errors(original_messages) |
| 4897 | + self.store_types(original_type_map) |
| 4898 | + return typ |
| 4899 | + |
4862 | 4900 | def check_return_stmt(self, s: ReturnStmt) -> None: |
4863 | 4901 | defn = self.scope.current_function() |
4864 | 4902 | if defn is not None: |
@@ -4891,11 +4929,18 @@ def check_return_stmt(self, s: ReturnStmt) -> None: |
4891 | 4929 | allow_none_func_call = is_lambda or declared_none_return or declared_any_return |
4892 | 4930 |
|
4893 | 4931 | # Return with a value. |
4894 | | - typ = get_proper_type( |
4895 | | - self.expr_checker.accept( |
4896 | | - s.expr, return_type, allow_none_return=allow_none_func_call |
| 4932 | + if isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr)): |
| 4933 | + # For expressions that (strongly) depend on type context (i.e. those that |
| 4934 | + # are handled like a function call), we allow fallback to empty type context |
| 4935 | + # in case of errors, this improves user experience in some cases, |
| 4936 | + # see e.g. testReturnFallbackInference. |
| 4937 | + typ = self.infer_context_dependent(s.expr, return_type, allow_none_func_call) |
| 4938 | + else: |
| 4939 | + typ = get_proper_type( |
| 4940 | + self.expr_checker.accept( |
| 4941 | + s.expr, return_type, allow_none_return=allow_none_func_call |
| 4942 | + ) |
4897 | 4943 | ) |
4898 | | - ) |
4899 | 4944 | # Treat NotImplemented as having type Any, consistent with its |
4900 | 4945 | # definition in typeshed prior to python/typeshed#4222. |
4901 | 4946 | if ( |
|
0 commit comments