Skip to content

Commit 99725c9

Browse files
committed
Fix --strict-equality for iteratively visited code.
1 parent 660d911 commit 99725c9

File tree

3 files changed

+92
-6
lines changed

3 files changed

+92
-6
lines changed

mypy/errors.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,9 @@ def filtered_errors(self) -> list[ErrorInfo]:
230230

231231
class IterationDependentErrors:
232232
"""An `IterationDependentErrors` instance serves to collect the `unreachable`,
233-
`redundant-expr`, and `redundant-casts` errors, as well as the revealed types,
234-
handled by the individual `IterationErrorWatcher` instances sequentially applied to
235-
the same code section."""
233+
`redundant-expr`, and `redundant-casts` errors, as well as the revealed types and
234+
non-overlapping types, handled by the individual `IterationErrorWatcher` instances
235+
sequentially applied to the same code section."""
236236

237237
# One set of `unreachable`, `redundant-expr`, and `redundant-casts` errors per
238238
# iteration step. Meaning of the tuple items: ErrorCode, message, line, column,
@@ -248,9 +248,18 @@ class IterationDependentErrors:
248248
# end_line, end_column:
249249
revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]]
250250

251+
# One dictionary of non-overlapping types per iteration step. Meaning of the key
252+
# tuple items: line, column, end_line, end_column, kind:
253+
nonoverlapping_types: list[
254+
dict[
255+
tuple[int, int, int | None, int | None, str], tuple[Type, Type]
256+
],
257+
]
258+
251259
def __init__(self) -> None:
252260
self.uselessness_errors = []
253261
self.unreachable_lines = []
262+
self.nonoverlapping_types = []
254263
self.revealed_types = defaultdict(list)
255264

256265
def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]:
@@ -270,6 +279,39 @@ def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCod
270279
context.end_column = error_info[5]
271280
yield error_info[1], context, error_info[0]
272281

282+
283+
def yield_nonoverlapping_types(self) -> Iterator[
284+
tuple[tuple[list[Type], list[Type]], str, Context]
285+
]:
286+
"""Report expressions were non-overlapping types were detected for all iterations
287+
were the expression was reachable."""
288+
289+
selected = set()
290+
for candidate in set(chain(*self.nonoverlapping_types)):
291+
if all(
292+
(candidate in nonoverlap) or (candidate[0] in lines)
293+
for nonoverlap, lines in zip(
294+
self.nonoverlapping_types, self.unreachable_lines
295+
)
296+
):
297+
selected.add(candidate)
298+
299+
persistent_nonoverlaps: dict[
300+
tuple[int, int, int | None, int | None, str], tuple[list[Type], list[Type]]
301+
] = defaultdict(lambda: ([], []))
302+
for nonoverlaps in self.nonoverlapping_types:
303+
for candidate, (left, right) in nonoverlaps.items():
304+
if candidate in selected:
305+
types = persistent_nonoverlaps[candidate]
306+
types[0].append(left)
307+
types[1].append(right)
308+
309+
for error_info, types in persistent_nonoverlaps.items():
310+
context = Context(line=error_info[0], column=error_info[1])
311+
context.end_line = error_info[2]
312+
context.end_column = error_info[3]
313+
yield (types[0], types[1]), error_info[4], context
314+
273315
def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]:
274316
"""Yield all types revealed in at least one iteration step."""
275317

@@ -282,8 +324,9 @@ def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]:
282324

283325
class IterationErrorWatcher(ErrorWatcher):
284326
"""Error watcher that filters and separately collects `unreachable` errors,
285-
`redundant-expr` and `redundant-casts` errors, and revealed types when analysing
286-
code sections iteratively to help avoid making too-hasty reports."""
327+
`redundant-expr` and `redundant-casts` errors, and revealed types and
328+
non-overlapping types when analysing code sections iteratively to help avoid
329+
making too-hasty reports."""
287330

288331
iteration_dependent_errors: IterationDependentErrors
289332

@@ -304,6 +347,7 @@ def __init__(
304347
)
305348
self.iteration_dependent_errors = iteration_dependent_errors
306349
iteration_dependent_errors.uselessness_errors.append(set())
350+
iteration_dependent_errors.nonoverlapping_types.append({})
307351
iteration_dependent_errors.unreachable_lines.append(set())
308352

309353
def on_error(self, file: str, info: ErrorInfo) -> bool:

mypy/messages.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1625,6 +1625,19 @@ def incompatible_typevar_value(
16251625
)
16261626

16271627
def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None:
1628+
1629+
# In loops (and similar cases), the same expression might be analysed multiple
1630+
# times and thereby confronted with different types. We only want to raise a
1631+
# `comparison-overlap` error if it occurs in all cases and therefore collect the
1632+
# respective types of the current iteration here so that we can report the error
1633+
# later if it is persistent over all iteration steps:
1634+
for watcher in self.errors.get_watchers():
1635+
if isinstance(watcher, IterationErrorWatcher):
1636+
watcher.iteration_dependent_errors.nonoverlapping_types[-1][
1637+
(ctx.line, ctx.column, ctx.end_line, ctx.end_column, kind)
1638+
] = (left, right)
1639+
return
1640+
16281641
left_str = "element" if kind == "container" else "left operand"
16291642
right_str = "container item" if kind == "container" else "right operand"
16301643
message = "Non-overlapping {} check ({} type: {}, {} type: {})"
@@ -2511,8 +2524,11 @@ def match_statement_inexhaustive_match(self, typ: Type, context: Context) -> Non
25112524
def iteration_dependent_errors(self, iter_errors: IterationDependentErrors) -> None:
25122525
for error_info in iter_errors.yield_uselessness_error_infos():
25132526
self.fail(*error_info[:2], code=error_info[2])
2527+
msu = mypy.typeops.make_simplified_union
2528+
for nonoverlaps, kind, context in iter_errors.yield_nonoverlapping_types():
2529+
self.dangerous_comparison(msu(nonoverlaps[0]), msu(nonoverlaps[1]), kind, context)
25142530
for types, context in iter_errors.yield_revealed_type_infos():
2515-
self.reveal_type(mypy.typeops.make_simplified_union(types), context)
2531+
self.reveal_type(msu(types), context)
25162532

25172533

25182534
def quote_type_string(type_string: str) -> str:

test-data/unit/check-narrowing.test

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2446,6 +2446,32 @@ while x is not None and b():
24462446
x = f()
24472447
[builtins fixtures/primitives.pyi]
24482448

2449+
[case testAvoidFalseNonOverlappingEqualityCheckInLoop1]
2450+
# flags: --allow-redefinition-new --local-partial-types --strict-equality
2451+
2452+
x = 1
2453+
while True:
2454+
if x == str():
2455+
break
2456+
x = str()
2457+
if x == int(): # E: Non-overlapping equality check (left operand type: "str", right operand type: "int")
2458+
break
2459+
[builtins fixtures/primitives.pyi]
2460+
2461+
[case testAvoidFalseNonOverlappingEqualityCheckInLoop2]
2462+
# flags: --allow-redefinition-new --local-partial-types --strict-equality
2463+
2464+
class A: ...
2465+
class B: ...
2466+
class C: ...
2467+
2468+
x = A()
2469+
while True:
2470+
if x == C(): # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "C")
2471+
break
2472+
x = B()
2473+
[builtins fixtures/primitives.pyi]
2474+
24492475
[case testNarrowPromotionsInsideUnions1]
24502476

24512477
from typing import Union

0 commit comments

Comments
 (0)