Skip to content
49 changes: 44 additions & 5 deletions mypy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,9 @@ def filtered_errors(self) -> list[ErrorInfo]:

class IterationDependentErrors:
"""An `IterationDependentErrors` instance serves to collect the `unreachable`,
`redundant-expr`, and `redundant-casts` errors, as well as the revealed types,
handled by the individual `IterationErrorWatcher` instances sequentially applied to
the same code section."""
`redundant-expr`, and `redundant-casts` errors, as well as the revealed types and
non-overlapping types, handled by the individual `IterationErrorWatcher` instances
sequentially applied to the same code section."""

# One set of `unreachable`, `redundant-expr`, and `redundant-casts` errors per
# iteration step. Meaning of the tuple items: ErrorCode, message, line, column,
Expand All @@ -249,9 +249,16 @@ class IterationDependentErrors:
# end_line, end_column:
revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]]

# One dictionary of non-overlapping types per iteration step. Meaning of the key
# tuple items: line, column, end_line, end_column, kind:
nonoverlapping_types: list[
dict[tuple[int, int, int | None, int | None, str], tuple[Type, Type]],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe NamedTuple or at least TypeAlias? I'm personally lost in brackets here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, TypeAlias doesn't help much. But using NamedTuple (here and in the similar cases above) would definitely increase readability. I will adjust it if there are no performance concerns.

]

def __init__(self) -> None:
self.uselessness_errors = []
self.unreachable_lines = []
self.nonoverlapping_types = []
self.revealed_types = defaultdict(list)

def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]:
Expand All @@ -271,6 +278,36 @@ def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCod
context.end_column = error_info[5]
yield error_info[1], context, error_info[0]

def yield_nonoverlapping_types(
self,
) -> Iterator[tuple[tuple[list[Type], list[Type]], str, Context]]:
"""Report expressions were non-overlapping types were detected for all iterations
were the expression was reachable."""

selected = set()
for candidate in set(chain(*self.nonoverlapping_types)):
if all(
(candidate in nonoverlap) or (candidate[0] in lines)
for nonoverlap, lines in zip(self.nonoverlapping_types, self.unreachable_lines)
):
selected.add(candidate)

persistent_nonoverlaps: dict[
tuple[int, int, int | None, int | None, str], tuple[list[Type], list[Type]]
] = defaultdict(lambda: ([], []))
for nonoverlaps in self.nonoverlapping_types:
for candidate, (left, right) in nonoverlaps.items():
if candidate in selected:
types = persistent_nonoverlaps[candidate]
types[0].append(left)
types[1].append(right)

for error_info, types in persistent_nonoverlaps.items():
context = Context(line=error_info[0], column=error_info[1])
context.end_line = error_info[2]
context.end_column = error_info[3]
yield (types[0], types[1]), error_info[4], context

def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]:
"""Yield all types revealed in at least one iteration step."""

Expand All @@ -283,8 +320,9 @@ def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]:

class IterationErrorWatcher(ErrorWatcher):
"""Error watcher that filters and separately collects `unreachable` errors,
`redundant-expr` and `redundant-casts` errors, and revealed types when analysing
code sections iteratively to help avoid making too-hasty reports."""
`redundant-expr` and `redundant-casts` errors, and revealed types and
non-overlapping types when analysing code sections iteratively to help avoid
making too-hasty reports."""

iteration_dependent_errors: IterationDependentErrors

Expand All @@ -305,6 +343,7 @@ def __init__(
)
self.iteration_dependent_errors = iteration_dependent_errors
iteration_dependent_errors.uselessness_errors.append(set())
iteration_dependent_errors.nonoverlapping_types.append({})
iteration_dependent_errors.unreachable_lines.append(set())

def on_error(self, file: str, info: ErrorInfo) -> bool:
Expand Down
20 changes: 19 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,21 @@ def incompatible_typevar_value(
)

def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None:

# In loops (and similar cases), the same expression might be analysed multiple
# times and thereby confronted with different types. We only want to raise a
# `comparison-overlap` error if it occurs in all cases and therefore collect the
# respective types of the current iteration here so that we can report the error
# later if it is persistent over all iteration steps:
for watcher in self.errors.get_watchers():
if watcher._filter:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you move this logic into the watcher itself? This access to private followed by special-casing on watcher type is a bit painful to read, IMO this would be clearer as

for watcher in self.errors.get_watchers():
    if watcher.store_nonoverlapping_types(ctx, kind, left, right):
        return

Where store_nonoverlapping_types (not the best name) is a no-op in ErrorWatcher, overridden with this block in IterationErrorWatcher

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ough, you really need both break and return, sorry.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to excuse. I am obviously using things a little differently from how they were originally intended. Hence, I highly appreciate any thoughts on improving readability.

break
if isinstance(watcher, IterationErrorWatcher):
watcher.iteration_dependent_errors.nonoverlapping_types[-1][
(ctx.line, ctx.column, ctx.end_line, ctx.end_column, kind)
] = (left, right)
return

left_str = "element" if kind == "container" else "left operand"
right_str = "container item" if kind == "container" else "right operand"
message = "Non-overlapping {} check ({} type: {}, {} type: {})"
Expand Down Expand Up @@ -2514,8 +2529,11 @@ def match_statement_inexhaustive_match(self, typ: Type, context: Context) -> Non
def iteration_dependent_errors(self, iter_errors: IterationDependentErrors) -> None:
for error_info in iter_errors.yield_uselessness_error_infos():
self.fail(*error_info[:2], code=error_info[2])
msu = mypy.typeops.make_simplified_union
for nonoverlaps, kind, context in iter_errors.yield_nonoverlapping_types():
self.dangerous_comparison(msu(nonoverlaps[0]), msu(nonoverlaps[1]), kind, context)
for types, context in iter_errors.yield_revealed_type_infos():
self.reveal_type(mypy.typeops.make_simplified_union(types), context)
self.reveal_type(msu(types), context)


def quote_type_string(type_string: str) -> str:
Expand Down
35 changes: 35 additions & 0 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -2446,6 +2446,41 @@ while x is not None and b():
x = f()
[builtins fixtures/primitives.pyi]

[case testAvoidFalseNonOverlappingEqualityCheckInLoop1]
# flags: --allow-redefinition-new --local-partial-types --strict-equality

x = 1
while True:
if x == str():
break
x = str()
if x == int(): # E: Non-overlapping equality check (left operand type: "str", right operand type: "int")
break
[builtins fixtures/primitives.pyi]

[case testAvoidFalseNonOverlappingEqualityCheckInLoop2]
# flags: --allow-redefinition-new --local-partial-types --strict-equality

class A: ...
class B: ...
class C: ...

x = A()
while True:
if x == C(): # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "C")
break
x = B()
[builtins fixtures/primitives.pyi]

[case testAvoidFalseNonOverlappingEqualityCheckInLoop3]
# flags: --strict-equality

for y in [1.0]:
if y is not None or y != "None":
...

[builtins fixtures/primitives.pyi]

[case testNarrowPromotionsInsideUnions1]

from typing import Union
Expand Down