diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index 38154cf697e1..ada28ec4bd3c 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -295,6 +295,8 @@ def is_undefined(self, name: str) -> bool: class Loop: def __init__(self) -> None: self.has_break = False + # variables defined in every loop branch with `break` + self.break_vars: set[str] | None = None class PossiblyUndefinedVariableVisitor(ExtendedTraverserVisitor): @@ -472,6 +474,9 @@ def visit_for_stmt(self, o: ForStmt) -> None: has_break = loop.has_break if has_break: self.tracker.start_branch_statement() + if loop.break_vars is not None: + for bv in loop.break_vars: + self.tracker.record_definition(bv) self.tracker.next_branch() o.else_body.accept(self) if has_break: @@ -504,6 +509,14 @@ def visit_break_stmt(self, o: BreakStmt) -> None: super().visit_break_stmt(o) if self.loops: self.loops[-1].has_break = True + # Track variables that are definitely defined at the point of break + if len(self.tracker._scope().branch_stmts) > 0: + branch = self.tracker._scope().branch_stmts[-1].branches[-1] + if self.loops[-1].break_vars is None: + self.loops[-1].break_vars = set(branch.must_be_defined) + else: + # we only want variables that have been defined in each branch + self.loops[-1].break_vars.intersection_update(branch.must_be_defined) self.tracker.skip_branch() def visit_expression_stmt(self, o: ExpressionStmt) -> None: diff --git a/test-data/unit/check-possibly-undefined.test b/test-data/unit/check-possibly-undefined.test index ae277949c049..3bf52e5b8847 100644 --- a/test-data/unit/check-possibly-undefined.test +++ b/test-data/unit/check-possibly-undefined.test @@ -1043,3 +1043,126 @@ def foo(x: Union[int, str]) -> None: assert_never(x) f # OK [builtins fixtures/tuple.pyi] + +[case testForElseWithBreakInTryExceptContinue] +# flags: --enable-error-code possibly-undefined +# Test for issue where variable defined before break in try block +# was incorrectly reported as undefined when except has continue +def random() -> float: return 0.5 + +if random(): + for i in range(10): + try: + value = random() + break + except Exception: + continue + else: + raise RuntimeError + + print(value) # Should not error - value is defined if we broke +else: + value = random() + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseWithBreakInTryExceptContinueNoIf] +# flags: --enable-error-code possibly-undefined +# Simpler version without if statement +def random() -> float: return 0.5 + +for i in range(10): + try: + value = random() + break + except Exception: + continue +else: + raise RuntimeError + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseWithBreakInTryExceptPass] +# flags: --enable-error-code possibly-undefined +# Version with pass instead of continue - should also work +def random() -> float: return 0.5 + +if random(): + for i in range(10): + try: + value = random() + break + except Exception: + pass + else: + raise RuntimeError + + print(value) # Should not error +else: + value = random() + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseWithConditionalDefBeforeBreak] +# flags: --enable-error-code possibly-undefined +# Test that conditional definition before break still works correctly +def random() -> float: return 0.5 + +if random(): + for i in range(10): + try: + if i > 10: + value = random() + break + except Exception: + continue + else: + raise RuntimeError + + print(value) # Should not error (though might be undefined at runtime) +else: + value = random() + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseDefineInBothBranches] +# flags: --enable-error-code possibly-undefined +# Test that variable defined in both for break and else branches is not undefined +for i in range(10): + if i: + value = i + break +else: + value = 1 + +print(value) # Should not error - value is defined in all paths +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] + +[case testForElseWithWalrusInBreak] +# flags: --enable-error-code possibly-undefined +# Test with walrus operator in if condition before break +def random() -> float: return 0.5 + +if random(): + for i in range(10): + if value := random(): + break + else: + raise RuntimeError + + print(value) # Should not error - value is defined if we broke +else: + value = random() + +print(value) # Should not error +[builtins fixtures/for_else_exception.pyi] +[typing fixtures/typing-medium.pyi] diff --git a/test-data/unit/fixtures/for_else_exception.pyi b/test-data/unit/fixtures/for_else_exception.pyi new file mode 100644 index 000000000000..98c953caff72 --- /dev/null +++ b/test-data/unit/fixtures/for_else_exception.pyi @@ -0,0 +1,54 @@ +# Fixture for for-else tests with exceptions +# Combines needed elements from primitives.pyi and exception.pyi + +from typing import Generic, Iterator, Mapping, Sequence, TypeVar + +T = TypeVar('T') +V = TypeVar('V') + +class object: + def __init__(self) -> None: pass +class type: + def __init__(self, x: object) -> None: pass +class int: + def __init__(self, x: object = ..., base: int = ...) -> None: pass + def __add__(self, i: int) -> int: pass + def __rmul__(self, x: int) -> int: pass + def __bool__(self) -> bool: pass + def __eq__(self, x: object) -> bool: pass + def __ne__(self, x: object) -> bool: pass + def __lt__(self, x: 'int') -> bool: pass + def __le__(self, x: 'int') -> bool: pass + def __gt__(self, x: 'int') -> bool: pass + def __ge__(self, x: 'int') -> bool: pass +class float: + def __float__(self) -> float: pass + def __add__(self, x: float) -> float: pass + def hex(self) -> str: pass +class bool(int): pass +class str(Sequence[str]): + def __add__(self, s: str) -> str: pass + def __iter__(self) -> Iterator[str]: pass + def __contains__(self, other: object) -> bool: pass + def __getitem__(self, item: int) -> str: pass + def format(self, *args: object, **kwargs: object) -> str: pass +class dict(Mapping[T, V]): + def __iter__(self) -> Iterator[T]: pass +class tuple(Generic[T]): + def __contains__(self, other: object) -> bool: pass +class ellipsis: pass + +class BaseException: + def __init__(self, *args: object) -> None: ... +class Exception(BaseException): pass +class RuntimeError(Exception): pass + +class range(Sequence[int]): + def __init__(self, __x: int, __y: int = ..., __z: int = ...) -> None: pass + def count(self, value: int) -> int: pass + def index(self, value: int) -> int: pass + def __getitem__(self, i: int) -> int: pass + def __iter__(self) -> Iterator[int]: pass + def __contains__(self, other: object) -> bool: pass + +def print(x: object) -> None: pass