Skip to content

Commit ac3e599

Browse files
committed
Combine the revealed types of multiple iteration steps in a more robust manner.
1 parent b678d9f commit ac3e599

File tree

8 files changed

+63
-48
lines changed

8 files changed

+63
-48
lines changed

mypy/checker.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -643,8 +643,8 @@ def accept_loop(
643643

644644
for error_info in watcher.yield_error_infos():
645645
self.msg.fail(*error_info[:2], code=error_info[2])
646-
for note_info in watcher.yield_note_infos(self.options):
647-
self.note(*note_info)
646+
for note_info, context in watcher.yield_note_infos(self.options):
647+
self.msg.reveal_type(note_info, context)
648648

649649
# If exit_condition is set, assume it must be False on exit from the loop:
650650
if exit_condition:
@@ -3037,7 +3037,7 @@ def is_noop_for_reachability(self, s: Statement) -> bool:
30373037
if isinstance(s.expr, EllipsisExpr):
30383038
return True
30393039
elif isinstance(s.expr, CallExpr):
3040-
with self.expr_checker.msg.filter_errors():
3040+
with self.expr_checker.msg.filter_errors(filter_revealed_type=True):
30413041
typ = get_proper_type(
30423042
self.expr_checker.accept(
30433043
s.expr, allow_none_return=True, always_allow_any=True
@@ -4987,8 +4987,8 @@ def visit_try_stmt(self, s: TryStmt) -> None:
49874987

49884988
for error_info in watcher.yield_error_infos():
49894989
self.msg.fail(*error_info[:2], code=error_info[2])
4990-
for note_info in watcher.yield_note_infos(self.options):
4991-
self.msg.note(*note_info)
4990+
for note_info, context in watcher.yield_note_infos(self.options):
4991+
self.msg.reveal_type(note_info, context)
49924992

49934993
def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
49944994
"""Type check a try statement, ignoring the finally block.

mypy/errors.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from mypy.nodes import Context
1616
from mypy.options import Options
1717
from mypy.scope import Scope
18+
from mypy.typeops import make_simplified_union
19+
from mypy.types import Type
1820
from mypy.util import DEFAULT_SOURCE_OFFSET, is_typeshed_file
1921
from mypy.version import __version__ as mypy_version
2022

@@ -166,18 +168,24 @@ class ErrorWatcher:
166168
out by one of the ErrorWatcher instances.
167169
"""
168170

171+
# public attribute for the special treatment of `reveal_type` by
172+
# `MessageBuilder.reveal_type`:
173+
filter_revealed_type: bool
174+
169175
def __init__(
170176
self,
171177
errors: Errors,
172178
*,
173179
filter_errors: bool | Callable[[str, ErrorInfo], bool] = False,
174180
save_filtered_errors: bool = False,
175181
filter_deprecated: bool = False,
182+
filter_revealed_type: bool = False
176183
) -> None:
177184
self.errors = errors
178185
self._has_new_errors = False
179186
self._filter = filter_errors
180187
self._filter_deprecated = filter_deprecated
188+
self.filter_revealed_type = filter_revealed_type
181189
self._filtered: list[ErrorInfo] | None = [] if save_filtered_errors else None
182190

183191
def __enter__(self) -> Self:
@@ -236,15 +244,15 @@ class IterationDependentErrors:
236244
# the error report occurs but really all unreachable lines.
237245
unreachable_lines: list[set[int]]
238246

239-
# One set of revealed types for each `reveal_type` statement. Each created set can
240-
# grow during the iteration. Meaning of the tuple items: function_or_member, line,
241-
# column, end_line, end_column:
242-
revealed_types: dict[tuple[str | None, int, int, int, int], set[str]]
247+
# One list of revealed types for each `reveal_type` statement. Each created list
248+
# can grow during the iteration. Meaning of the tuple items: line, column,
249+
# end_line, end_column:
250+
revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]]
243251

244252
def __init__(self) -> None:
245253
self.uselessness_errors = []
246254
self.unreachable_lines = []
247-
self.revealed_types = defaultdict(set)
255+
self.revealed_types = defaultdict(list)
248256

249257

250258
class IterationErrorWatcher(ErrorWatcher):
@@ -287,15 +295,6 @@ def on_error(self, file: str, info: ErrorInfo) -> bool:
287295
iter_errors.unreachable_lines[-1].update(range(info.line, info.end_line + 1))
288296
return True
289297

290-
if info.code == codes.MISC and info.message.startswith("Revealed type is "):
291-
key = info.function_or_member, info.line, info.column, info.end_line, info.end_column
292-
types = info.message.split('"')[1]
293-
if types.startswith("Union["):
294-
iter_errors.revealed_types[key].update(types[6:-1].split(", "))
295-
else:
296-
iter_errors.revealed_types[key].add(types)
297-
return True
298-
299298
return super().on_error(file, info)
300299

301300
def yield_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]:
@@ -318,21 +317,14 @@ def yield_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]:
318317
context.end_column = error_info[5]
319318
yield error_info[1], context, error_info[0]
320319

321-
def yield_note_infos(self, options: Options) -> Iterator[tuple[str, Context]]:
320+
def yield_note_infos(self, options: Options) -> Iterator[tuple[Type, Context]]:
322321
"""Yield all types revealed in at least one iteration step."""
323322

324323
for note_info, types in self.iteration_dependent_errors.revealed_types.items():
325-
sorted_ = sorted(types, key=lambda typ: typ.lower())
326-
if len(types) == 1:
327-
revealed = sorted_[0]
328-
elif options.use_or_syntax():
329-
revealed = " | ".join(sorted_)
330-
else:
331-
revealed = f"Union[{', '.join(sorted_)}]"
332-
context = Context(line=note_info[1], column=note_info[2])
333-
context.end_line = note_info[3]
334-
context.end_column = note_info[4]
335-
yield f'Revealed type is "{revealed}"', context
324+
context = Context(line=note_info[0], column=note_info[1])
325+
context.end_line = note_info[2]
326+
context.end_column = note_info[3]
327+
yield make_simplified_union(types), context
336328

337329

338330
class Errors:
@@ -596,18 +588,20 @@ def _add_error_info(self, file: str, info: ErrorInfo) -> None:
596588
if info.code in (IMPORT, IMPORT_UNTYPED, IMPORT_NOT_FOUND):
597589
self.seen_import_error = True
598590

591+
@property
592+
def watchers(self) -> Iterator[ErrorWatcher]:
593+
"""Yield the `ErrorWatcher` stack from top to bottom."""
594+
i = len(self._watchers)
595+
while i > 0:
596+
i -= 1
597+
yield self._watchers[i]
598+
599599
def _filter_error(self, file: str, info: ErrorInfo) -> bool:
600600
"""
601601
process ErrorWatcher stack from top to bottom,
602602
stopping early if error needs to be filtered out
603603
"""
604-
i = len(self._watchers)
605-
while i > 0:
606-
i -= 1
607-
w = self._watchers[i]
608-
if w.on_error(file, info):
609-
return True
610-
return False
604+
return any(w.on_error(file, info) for w in self.watchers)
611605

612606
def add_error_info(self, info: ErrorInfo) -> None:
613607
file, lines = info.origin

mypy/messages.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from mypy import errorcodes as codes, message_registry
2424
from mypy.erasetype import erase_type
2525
from mypy.errorcodes import ErrorCode
26-
from mypy.errors import ErrorInfo, Errors, ErrorWatcher
26+
from mypy.errors import ErrorInfo, Errors, ErrorWatcher, IterationErrorWatcher
2727
from mypy.nodes import (
2828
ARG_NAMED,
2929
ARG_NAMED_OPT,
@@ -188,12 +188,14 @@ def filter_errors(
188188
filter_errors: bool | Callable[[str, ErrorInfo], bool] = True,
189189
save_filtered_errors: bool = False,
190190
filter_deprecated: bool = False,
191+
filter_revealed_type: bool = False,
191192
) -> ErrorWatcher:
192193
return ErrorWatcher(
193194
self.errors,
194195
filter_errors=filter_errors,
195196
save_filtered_errors=save_filtered_errors,
196197
filter_deprecated=filter_deprecated,
198+
filter_revealed_type=filter_revealed_type,
197199
)
198200

199201
def add_errors(self, errors: list[ErrorInfo]) -> None:
@@ -1738,6 +1740,24 @@ def invalid_signature_for_special_method(
17381740
)
17391741

17401742
def reveal_type(self, typ: Type, context: Context) -> None:
1743+
1744+
# Search for an error watcher that modifies the "normal" behaviour (we do not
1745+
# rely on the normal `ErrorWatcher` filtering approach because we might need to
1746+
# collect the original types for a later unionised response):
1747+
for watcher in self.errors.watchers:
1748+
# The `reveal_type` statement should be ignored:
1749+
if watcher.filter_revealed_type:
1750+
return
1751+
# The `reveal_type` statement might be visited iteratively due to being
1752+
# placed in a loop or so. Hence, we collect the respective types of
1753+
# individual iterations so that we can report them all in one step later:
1754+
if isinstance(watcher, IterationErrorWatcher):
1755+
watcher.iteration_dependent_errors.revealed_types[
1756+
(context.line, context.column, context.end_line, context.end_column)
1757+
].append(typ)
1758+
return
1759+
1760+
# Nothing special here; just create the note:
17411761
visitor = TypeStrVisitor(options=self.options)
17421762
self.note(f'Revealed type is "{typ.accept(visitor)}"', context)
17431763

test-data/unit/check-inference.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ for var2 in [g, h, i, j, k, l]:
343343
reveal_type(var2) # N: Revealed type is "Union[builtins.int, builtins.str]"
344344

345345
for var3 in [m, n, o, p, q, r]:
346-
reveal_type(var3) # N: Revealed type is "Union[Any, builtins.int]"
346+
reveal_type(var3) # N: Revealed type is "Union[builtins.int, Any]"
347347

348348
T = TypeVar("T", bound=Type[Foo])
349349

@@ -1247,7 +1247,7 @@ class X(TypedDict):
12471247

12481248
x: X
12491249
for a in ("hourly", "daily"):
1250-
reveal_type(a) # N: Revealed type is "Union[Literal['daily']?, Literal['hourly']?]"
1250+
reveal_type(a) # N: Revealed type is "Union[Literal['hourly']?, Literal['daily']?]"
12511251
reveal_type(x[a]) # N: Revealed type is "builtins.int"
12521252
reveal_type(a.upper()) # N: Revealed type is "builtins.str"
12531253
c = a

test-data/unit/check-narrowing.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2346,7 +2346,7 @@ def f() -> bool: ...
23462346

23472347
y = None
23482348
while f():
2349-
reveal_type(y) # N: Revealed type is "Union[builtins.int, None]"
2349+
reveal_type(y) # N: Revealed type is "Union[None, builtins.int]"
23502350
y = 1
23512351
reveal_type(y) # N: Revealed type is "Union[builtins.int, None]"
23522352

test-data/unit/check-redefine2.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def f1() -> None:
628628
def f2() -> None:
629629
x = None
630630
while int():
631-
reveal_type(x) # N: Revealed type is "Union[builtins.str, None]"
631+
reveal_type(x) # N: Revealed type is "Union[None, builtins.str]"
632632
if int():
633633
x = ""
634634
reveal_type(x) # N: Revealed type is "Union[None, builtins.str]"
@@ -923,7 +923,7 @@ class X(TypedDict):
923923

924924
x: X
925925
for a in ("hourly", "daily"):
926-
reveal_type(a) # N: Revealed type is "Union[Literal['daily']?, Literal['hourly']?]"
926+
reveal_type(a) # N: Revealed type is "Union[Literal['hourly']?, Literal['daily']?]"
927927
reveal_type(x[a]) # N: Revealed type is "builtins.int"
928928
reveal_type(a.upper()) # N: Revealed type is "builtins.str"
929929
c = a

test-data/unit/check-typevar-tuple.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ from typing_extensions import Unpack
989989

990990
def pipeline(*xs: Unpack[Tuple[int, Unpack[Tuple[float, ...]], bool]]) -> None:
991991
for x in xs:
992-
reveal_type(x) # N: Revealed type is "Union[builtins.float, builtins.int]"
992+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float]"
993993
[builtins fixtures/tuple.pyi]
994994

995995
[case testFixedUnpackItemInInstanceArguments]

test-data/unit/check-union-error-syntax.test

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,18 @@ x = 3 # E: Incompatible types in assignment (expression has type "Literal[3]", v
6262
try:
6363
x = 1
6464
x = ""
65+
x = {1: ""}
6566
finally:
66-
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str]"
67+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.str, builtins.dict[builtins.int, builtins.str]]"
6768
[builtins fixtures/isinstancelist.pyi]
6869

6970
[case testOrSyntaxRecombined]
7071
# flags: --python-version 3.10 --no-force-union-syntax --allow-redefinition-new --local-partial-types
7172
# The following revealed type is recombined because the finally body is visited twice.
72-
# ToDo: Improve this recombination logic, especially (but not only) for the "or syntax".
7373
try:
7474
x = 1
7575
x = ""
76+
x = {1: ""}
7677
finally:
77-
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | builtins.str"
78+
reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | builtins.dict[builtins.int, builtins.str]"
7879
[builtins fixtures/isinstancelist.pyi]

0 commit comments

Comments
 (0)