Skip to content

Commit a856e55

Browse files
authored
Use empty context as fallback for return statements (#19767)
Fixes #16924 Fixes #15886 Mypy uses external type context first, this can cause bad type inference in return statements (see example in test case added), usually we recommend a workaround to users like replacing: ```python return foo(x) ``` with ```python y = foo(x) return y ``` But this is a bit ugly, and more importantly we can essentially automatically try this workaround. This is what this PR adds. I checked performance impact, and don't see any (but for some reason noise level on my desktop is much higher now).
1 parent 70d0521 commit a856e55

File tree

4 files changed

+109
-5
lines changed

4 files changed

+109
-5
lines changed

mypy/checker.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
ContinueStmt,
9191
Decorator,
9292
DelStmt,
93+
DictExpr,
9394
EllipsisExpr,
9495
Expression,
9596
ExpressionStmt,
@@ -124,6 +125,7 @@
124125
RaiseStmt,
125126
RefExpr,
126127
ReturnStmt,
128+
SetExpr,
127129
StarExpr,
128130
Statement,
129131
StrExpr,
@@ -4859,6 +4861,42 @@ def visit_return_stmt(self, s: ReturnStmt) -> None:
48594861
self.check_return_stmt(s)
48604862
self.binder.unreachable()
48614863

4864+
def infer_context_dependent(
4865+
self, expr: Expression, type_ctx: Type, allow_none_func_call: bool
4866+
) -> ProperType:
4867+
"""Infer type of an expression with fallback to empty type context."""
4868+
with self.msg.filter_errors(
4869+
filter_errors=True, filter_deprecated=True, save_filtered_errors=True
4870+
) as msg:
4871+
with self.local_type_map as type_map:
4872+
typ = get_proper_type(
4873+
self.expr_checker.accept(
4874+
expr, type_ctx, allow_none_return=allow_none_func_call
4875+
)
4876+
)
4877+
if not msg.has_new_errors():
4878+
self.store_types(type_map)
4879+
return typ
4880+
4881+
# If there are errors with the original type context, try re-inferring in empty context.
4882+
original_messages = msg.filtered_errors()
4883+
original_type_map = type_map
4884+
with self.msg.filter_errors(
4885+
filter_errors=True, filter_deprecated=True, save_filtered_errors=True
4886+
) as msg:
4887+
with self.local_type_map as type_map:
4888+
alt_typ = get_proper_type(
4889+
self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call)
4890+
)
4891+
if not msg.has_new_errors() and is_subtype(alt_typ, type_ctx):
4892+
self.store_types(type_map)
4893+
return alt_typ
4894+
4895+
# If empty fallback didn't work, use results from the original type context.
4896+
self.msg.add_errors(original_messages)
4897+
self.store_types(original_type_map)
4898+
return typ
4899+
48624900
def check_return_stmt(self, s: ReturnStmt) -> None:
48634901
defn = self.scope.current_function()
48644902
if defn is not None:
@@ -4891,11 +4929,18 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
48914929
allow_none_func_call = is_lambda or declared_none_return or declared_any_return
48924930

48934931
# Return with a value.
4894-
typ = get_proper_type(
4895-
self.expr_checker.accept(
4896-
s.expr, return_type, allow_none_return=allow_none_func_call
4932+
if isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr)):
4933+
# For expressions that (strongly) depend on type context (i.e. those that
4934+
# are handled like a function call), we allow fallback to empty type context
4935+
# in case of errors, this improves user experience in some cases,
4936+
# see e.g. testReturnFallbackInference.
4937+
typ = self.infer_context_dependent(s.expr, return_type, allow_none_func_call)
4938+
else:
4939+
typ = get_proper_type(
4940+
self.expr_checker.accept(
4941+
s.expr, return_type, allow_none_return=allow_none_func_call
4942+
)
48974943
)
4898-
)
48994944
# Treat NotImplemented as having type Any, consistent with its
49004945
# definition in typeshed prior to python/typeshed#4222.
49014946
if (

mypy/errors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ def on_error(self, file: str, info: ErrorInfo) -> bool:
206206
"""
207207
if info.code == codes.DEPRECATED:
208208
# Deprecated is not a type error, so it is handled on opt-in basis here.
209-
return self._filter_deprecated
209+
if not self._filter_deprecated:
210+
return False
210211

211212
self._has_new_errors = True
212213
if isinstance(self._filter, bool):

test-data/unit/check-inference-context.test

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,3 +1540,45 @@ def f(x: dict[str, Union[str, None, int]]) -> None:
15401540
def g(x: Optional[dict[str, Any]], s: Optional[str]) -> None:
15411541
f(x or {'x': s})
15421542
[builtins fixtures/dict.pyi]
1543+
1544+
[case testReturnFallbackInferenceTuple]
1545+
from typing import TypeVar, Union
1546+
1547+
T = TypeVar("T")
1548+
def foo(x: list[T]) -> tuple[T, ...]: ...
1549+
1550+
def bar(x: list[int]) -> tuple[Union[str, int], ...]:
1551+
return foo(x)
1552+
1553+
def bar2(x: list[int]) -> tuple[Union[str, int], ...]:
1554+
y = foo(x)
1555+
return y
1556+
[builtins fixtures/tuple.pyi]
1557+
1558+
[case testReturnFallbackInferenceUnion]
1559+
from typing import Generic, TypeVar, Union
1560+
1561+
T = TypeVar("T")
1562+
1563+
class Cls(Generic[T]):
1564+
pass
1565+
1566+
def inner(c: Cls[T]) -> Union[T, int]:
1567+
return 1
1568+
1569+
def outer(c: Cls[T]) -> Union[T, int]:
1570+
return inner(c)
1571+
1572+
[case testReturnFallbackInferenceAsync]
1573+
from typing import Generic, TypeVar, Optional
1574+
1575+
T = TypeVar("T")
1576+
1577+
class Cls(Generic[T]):
1578+
pass
1579+
1580+
async def inner(c: Cls[T]) -> Optional[T]:
1581+
return None
1582+
1583+
async def outer(c: Cls[T]) -> Optional[T]:
1584+
return await inner(c)

test-data/unit/pythoneval.test

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,3 +2188,19 @@ reveal_type([*map(str, x)])
21882188
[out]
21892189
_testUnpackIteratorBuiltins.py:4: note: Revealed type is "builtins.list[builtins.int]"
21902190
_testUnpackIteratorBuiltins.py:5: note: Revealed type is "builtins.list[builtins.str]"
2191+
2192+
[case testReturnFallbackInferenceDict]
2193+
# Requires full dict stubs.
2194+
from typing import Dict, Mapping, TypeVar, Union
2195+
2196+
K = TypeVar("K")
2197+
V = TypeVar("V")
2198+
K2 = TypeVar("K2")
2199+
V2 = TypeVar("V2")
2200+
2201+
def func(one: Dict[K, V], two: Mapping[K2, V2]) -> Dict[Union[K, K2], Union[V, V2]]:
2202+
...
2203+
2204+
def caller(arg1: Mapping[K, V], arg2: Mapping[K2, V2]) -> Dict[Union[K, K2], Union[V, V2]]:
2205+
_arg1 = arg1 if isinstance(arg1, dict) else dict(arg1)
2206+
return func(_arg1, arg2)

0 commit comments

Comments
 (0)