diff --git a/mypy/errors.py b/mypy/errors.py index d75c1c62a1ed..6fce24d42d24 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -230,9 +230,9 @@ def filtered_errors(self) -> list[ErrorInfo]: class IterationDependentErrors: """An `IterationDependentErrors` instance serves to collect the `unreachable`, - `redundant-expr`, and `redundant-casts` errors, as well as the revealed types, - handled by the individual `IterationErrorWatcher` instances sequentially applied to - the same code section.""" + `redundant-expr`, and `redundant-casts` errors, as well as the revealed types and + non-overlapping types, handled by the individual `IterationErrorWatcher` instances + sequentially applied to the same code section.""" # One set of `unreachable`, `redundant-expr`, and `redundant-casts` errors per # iteration step. Meaning of the tuple items: ErrorCode, message, line, column, @@ -248,9 +248,16 @@ class IterationDependentErrors: # end_line, end_column: revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]] + # One dictionary of non-overlapping types per iteration step. Meaning of the key + # tuple items: line, column, end_line, end_column, kind: + nonoverlapping_types: list[ + dict[tuple[int, int, int | None, int | None, str], tuple[Type, Type]], + ] + def __init__(self) -> None: self.uselessness_errors = [] self.unreachable_lines = [] + self.nonoverlapping_types = [] self.revealed_types = defaultdict(list) def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]: @@ -270,6 +277,36 @@ def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCod context.end_column = error_info[5] yield error_info[1], context, error_info[0] + def yield_nonoverlapping_types( + self, + ) -> Iterator[tuple[tuple[list[Type], list[Type]], str, Context]]: + """Report expressions were non-overlapping types were detected for all iterations + were the expression was reachable.""" + + selected = set() + for candidate in set(chain(*self.nonoverlapping_types)): + if all( + (candidate in nonoverlap) or (candidate[0] in lines) + for nonoverlap, lines in zip(self.nonoverlapping_types, self.unreachable_lines) + ): + selected.add(candidate) + + persistent_nonoverlaps: dict[ + tuple[int, int, int | None, int | None, str], tuple[list[Type], list[Type]] + ] = defaultdict(lambda: ([], [])) + for nonoverlaps in self.nonoverlapping_types: + for candidate, (left, right) in nonoverlaps.items(): + if candidate in selected: + types = persistent_nonoverlaps[candidate] + types[0].append(left) + types[1].append(right) + + for error_info, types in persistent_nonoverlaps.items(): + context = Context(line=error_info[0], column=error_info[1]) + context.end_line = error_info[2] + context.end_column = error_info[3] + yield (types[0], types[1]), error_info[4], context + def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]: """Yield all types revealed in at least one iteration step.""" @@ -282,8 +319,9 @@ def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]: class IterationErrorWatcher(ErrorWatcher): """Error watcher that filters and separately collects `unreachable` errors, - `redundant-expr` and `redundant-casts` errors, and revealed types when analysing - code sections iteratively to help avoid making too-hasty reports.""" + `redundant-expr` and `redundant-casts` errors, and revealed types and + non-overlapping types when analysing code sections iteratively to help avoid + making too-hasty reports.""" iteration_dependent_errors: IterationDependentErrors @@ -304,6 +342,7 @@ def __init__( ) self.iteration_dependent_errors = iteration_dependent_errors iteration_dependent_errors.uselessness_errors.append(set()) + iteration_dependent_errors.nonoverlapping_types.append({}) iteration_dependent_errors.unreachable_lines.append(set()) def on_error(self, file: str, info: ErrorInfo) -> bool: diff --git a/mypy/messages.py b/mypy/messages.py index f626d4c71916..95c74a14de8c 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1625,6 +1625,21 @@ def incompatible_typevar_value( ) def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None: + + # In loops (and similar cases), the same expression might be analysed multiple + # times and thereby confronted with different types. We only want to raise a + # `comparison-overlap` error if it occurs in all cases and therefore collect the + # respective types of the current iteration here so that we can report the error + # later if it is persistent over all iteration steps: + for watcher in self.errors.get_watchers(): + if watcher._filter: + break + if isinstance(watcher, IterationErrorWatcher): + watcher.iteration_dependent_errors.nonoverlapping_types[-1][ + (ctx.line, ctx.column, ctx.end_line, ctx.end_column, kind) + ] = (left, right) + return + left_str = "element" if kind == "container" else "left operand" right_str = "container item" if kind == "container" else "right operand" message = "Non-overlapping {} check ({} type: {}, {} type: {})" @@ -2511,8 +2526,11 @@ def match_statement_inexhaustive_match(self, typ: Type, context: Context) -> Non def iteration_dependent_errors(self, iter_errors: IterationDependentErrors) -> None: for error_info in iter_errors.yield_uselessness_error_infos(): self.fail(*error_info[:2], code=error_info[2]) + msu = mypy.typeops.make_simplified_union + for nonoverlaps, kind, context in iter_errors.yield_nonoverlapping_types(): + self.dangerous_comparison(msu(nonoverlaps[0]), msu(nonoverlaps[1]), kind, context) for types, context in iter_errors.yield_revealed_type_infos(): - self.reveal_type(mypy.typeops.make_simplified_union(types), context) + self.reveal_type(msu(types), context) def quote_type_string(type_string: str) -> str: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 7fffd3ce94e5..04f2c2fccd34 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2446,6 +2446,41 @@ while x is not None and b(): x = f() [builtins fixtures/primitives.pyi] +[case testAvoidFalseNonOverlappingEqualityCheckInLoop1] +# flags: --allow-redefinition-new --local-partial-types --strict-equality + +x = 1 +while True: + if x == str(): + break + x = str() + if x == int(): # E: Non-overlapping equality check (left operand type: "str", right operand type: "int") + break +[builtins fixtures/primitives.pyi] + +[case testAvoidFalseNonOverlappingEqualityCheckInLoop2] +# flags: --allow-redefinition-new --local-partial-types --strict-equality + +class A: ... +class B: ... +class C: ... + +x = A() +while True: + if x == C(): # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "C") + break + x = B() +[builtins fixtures/primitives.pyi] + +[case testAvoidFalseNonOverlappingEqualityCheckInLoop3] +# flags: --strict-equality + +for y in [1.0]: + if y is not None or y != "None": + ... + +[builtins fixtures/primitives.pyi] + [case testNarrowPromotionsInsideUnions1] from typing import Union