Skip to content

Commit ce69b76

Browse files
committed
Fix spurious possibly-undefined errors in for-else with break
When a `for` loop has an `else` block and its body contains branches that end in `break`, variables defined inside those branches were discarded from further analysis, so mypy could wrongly report a variable as possibly undefined after the loop, or as used before definition. With this fix, when the loop's `else` block is analyzed, every variable that is definitely defined at each `break` in the loop body is recorded as defined on the path that leaves the loop via `break`, i.e. the path that skips the `else` block. Fixes #14209. Fixes #19690.
1 parent 657bdd8 commit ce69b76

File tree

3 files changed

+246
-0
lines changed

3 files changed

+246
-0
lines changed

mypy/partially_defined.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@
4848
from mypy.types import Type, UninhabitedType, get_proper_type
4949

5050

51+
def _ambv(s: str) -> None:
52+
assert s
53+
pass # print("DEBUG:", s)
54+
55+
56+
def _ambv_cont(s: str) -> None:
57+
assert s
58+
pass # print(s)
59+
60+
5161
class BranchState:
5262
"""BranchState contains information about variable definition at the end of a branching statement.
5363
`if` and `match` are examples of branching statements.
@@ -117,6 +127,9 @@ def delete_var(self, name: str) -> None:
117127
def record_nested_branch(self, state: BranchState) -> None:
118128
assert len(self.branches) > 0
119129
current_branch = self.branches[-1]
130+
_ambv(
131+
f"record_nested_branch: state.must={state.must_be_defined if 'value' in state.must_be_defined else '...'}, state.may={'value' if 'value' in state.may_be_defined else '...'}, state.skipped={state.skipped}"
132+
)
120133
if state.skipped:
121134
current_branch.skipped = True
122135
return
@@ -154,6 +167,17 @@ def done(self) -> BranchState:
154167
all_vars.update(b.must_be_defined)
155168
# For the rest of the things, we only care about branches that weren't skipped.
156169
non_skipped_branches = [b for b in self.branches if not b.skipped]
170+
import sys
171+
172+
_called_by = sys._getframe(2).f_code.co_name
173+
_ambv(
174+
f"done {_called_by}: branches={len(self.branches)}, non_skipped={len(non_skipped_branches)}"
175+
)
176+
for i, b in enumerate(self.branches):
177+
has_value = "value" in b.must_be_defined or "value" in b.may_be_defined
178+
_ambv_cont(
179+
f" Branch {i}: has_value={has_value}, skipped={b.skipped}, must={b.must_be_defined}, may={b.may_be_defined}"
180+
)
157181
if non_skipped_branches:
158182
must_be_defined = non_skipped_branches[0].must_be_defined
159183
for b in non_skipped_branches[1:]:
@@ -163,6 +187,7 @@ def done(self) -> BranchState:
163187
# Everything that wasn't defined in all branches but was defined
164188
# in at least one branch should be in `may_be_defined`!
165189
may_be_defined = all_vars.difference(must_be_defined)
190+
_ambv_cont(f" Result: must={must_be_defined}, may={may_be_defined}")
166191
return BranchState(
167192
must_be_defined=must_be_defined,
168193
may_be_defined=may_be_defined,
@@ -295,6 +320,8 @@ def is_undefined(self, name: str) -> bool:
295320
class Loop:
296321
def __init__(self) -> None:
297322
self.has_break = False
323+
# Names that are definitely defined at every `break` visited so far in this loop (None until the first `break` is seen).
324+
self.break_vars: set[str] | None = None
298325

299326

300327
class PossiblyUndefinedVariableVisitor(ExtendedTraverserVisitor):
@@ -336,6 +363,10 @@ def __init__(
336363
for name in implicit_module_attrs:
337364
self.tracker.record_definition(name)
338365

366+
# def visit_block(self, block: Block, /) -> None:
367+
# _ambv(f"PossiblyUndefinedVariableVisitor visiting {block}")
368+
# super().visit_block(block)
369+
339370
def var_used_before_def(self, name: str, context: Context) -> None:
340371
if self.msg.errors.is_error_code_enabled(errorcodes.USED_BEFORE_DEF):
341372
self.msg.var_used_before_def(name, context)
@@ -349,6 +380,9 @@ def process_definition(self, name: str) -> None:
349380
if not self.tracker.in_scope(ScopeType.Class):
350381
refs = self.tracker.pop_undefined_ref(name)
351382
for ref in refs:
383+
_ambv(
384+
f"process_definition for {name}, ref at line {ref.line}, loops={bool(self.loops)}"
385+
)
352386
if self.loops:
353387
self.variable_may_be_undefined(name, ref)
354388
else:
@@ -370,6 +404,9 @@ def visit_nonlocal_decl(self, o: NonlocalDecl) -> None:
370404

371405
def process_lvalue(self, lvalue: Lvalue | None) -> None:
372406
if isinstance(lvalue, NameExpr):
407+
_ambv(
408+
f"process_lvalue calling process_definition for {lvalue.name} at line {lvalue.line}"
409+
)
373410
self.process_definition(lvalue.name)
374411
elif isinstance(lvalue, StarExpr):
375412
self.process_lvalue(lvalue.expr)
@@ -378,6 +415,7 @@ def process_lvalue(self, lvalue: Lvalue | None) -> None:
378415
self.process_lvalue(item)
379416

380417
def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
418+
_ambv(f"visit_assignment_stmt at line {o.line}")
381419
for lvalue in o.lvalues:
382420
self.process_lvalue(lvalue)
383421
super().visit_assignment_stmt(o)
@@ -456,22 +494,39 @@ def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None:
456494
self.tracker.exit_scope()
457495

458496
def visit_for_stmt(self, o: ForStmt) -> None:
497+
_ambv(f"visit_for_stmt: line {o.line}")
459498
o.expr.accept(self)
460499
self.process_lvalue(o.index)
461500
o.index.accept(self)
462501
self.tracker.start_branch_statement()
463502
loop = Loop()
464503
self.loops.append(loop)
504+
_ambv(
505+
f"visit_for_stmt: Before body, current state: {self.tracker._scope().branch_stmts[-1].branches[-1].must_be_defined}, may={self.tracker._scope().branch_stmts[-1].branches[-1].may_be_defined}"
506+
)
465507
o.body.accept(self)
508+
_ambv(f"visit_for_stmt: after body, has_break={loop.has_break}")
509+
_ambv(
510+
f"visit_for_stmt: After body, current state: {self.tracker._scope().branch_stmts[-1].branches[-1].must_be_defined}, may={self.tracker._scope().branch_stmts[-1].branches[-1].may_be_defined}"
511+
)
466512
self.tracker.next_branch()
513+
_ambv(
514+
f"visit_for_stmt: After next_branch, new branch state: {self.tracker._scope().branch_stmts[-1].branches[-1].must_be_defined}, may={self.tracker._scope().branch_stmts[-1].branches[-1].may_be_defined}"
515+
)
467516
self.tracker.end_branch_statement()
468517
if o.else_body is not None:
469518
# If the loop has a `break` inside, `else` is executed conditionally.
470519
# If the loop doesn't have a `break` either the function will return or
471520
# execute the `else`.
472521
has_break = loop.has_break
522+
_ambv(
523+
f"visit_for_stmt: else_body present, has_break={has_break}, break_vars={loop.break_vars}"
524+
)
473525
if has_break:
474526
self.tracker.start_branch_statement()
527+
if loop.break_vars is not None:
528+
for bv in loop.break_vars:
529+
self.tracker.record_definition(bv)
475530
self.tracker.next_branch()
476531
o.else_body.accept(self)
477532
if has_break:
@@ -497,13 +552,22 @@ def visit_raise_stmt(self, o: RaiseStmt) -> None:
497552
self.tracker.skip_branch()
498553

499554
def visit_continue_stmt(self, o: ContinueStmt) -> None:
555+
_ambv(f"continue at line {o.line}, skipping branch")
500556
super().visit_continue_stmt(o)
501557
self.tracker.skip_branch()
502558

503559
def visit_break_stmt(self, o: BreakStmt) -> None:
504560
super().visit_break_stmt(o)
505561
if self.loops:
506562
self.loops[-1].has_break = True
563+
# Track variables that are definitely defined at the point of break
564+
if len(self.tracker._scope().branch_stmts) > 0:
565+
branch = self.tracker._scope().branch_stmts[-1].branches[-1]
566+
if self.loops[-1].break_vars is None:
567+
self.loops[-1].break_vars = set(branch.must_be_defined)
568+
else:
569+
# we only want variables that have been defined in each branch
570+
self.loops[-1].break_vars.intersection_update(branch.must_be_defined)
507571
self.tracker.skip_branch()
508572

509573
def visit_expression_stmt(self, o: ExpressionStmt) -> None:
@@ -545,6 +609,7 @@ def f() -> int:
545609
self.try_depth -= 1
546610

547611
def process_try_stmt(self, o: TryStmt) -> None:
612+
_ambv(f"process_try_stmt: line {o.line}, handlers={len(o.handlers)}")
548613
"""
549614
Processes try statement decomposing it into the following:
550615
if ...:
@@ -620,6 +685,9 @@ def visit_starred_pattern(self, o: StarredPattern) -> None:
620685
def visit_name_expr(self, o: NameExpr) -> None:
621686
if o.name in self.builtins and self.tracker.in_scope(ScopeType.Global):
622687
return
688+
_ambv(
689+
f"visit_name_expr {o.name} at line {o.line}, possibly_undefined={self.tracker.is_possibly_undefined(o.name)}, defined_in_different_branch={self.tracker.is_defined_in_different_branch(o.name)}, is_undefined={self.tracker.is_undefined(o.name)}"
690+
)
623691
if self.tracker.is_possibly_undefined(o.name):
624692
# A variable is only defined in some branches.
625693
self.variable_may_be_undefined(o.name, o)
@@ -640,6 +708,7 @@ def visit_name_expr(self, o: NameExpr) -> None:
640708
# Case (1) will be caught by semantic analyzer. Case (2) is a forward ref that should
641709
# be caught by this visitor. Save the ref for later, so that if we see a definition,
642710
# we know it's a used-before-definition scenario.
711+
_ambv(f"Recording undefined ref for {o.name} at line {o.line}")
643712
self.tracker.record_undefined_ref(o)
644713
super().visit_name_expr(o)
645714

test-data/unit/check-possibly-undefined.test

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,3 +1043,126 @@ def foo(x: Union[int, str]) -> None:
10431043
assert_never(x)
10441044
f # OK
10451045
[builtins fixtures/tuple.pyi]
1046+
1047+
[case testForElseWithBreakInTryExceptContinue]
1048+
# flags: --enable-error-code possibly-undefined
1049+
# Test for issue where variable defined before break in try block
1050+
# was incorrectly reported as undefined when except has continue
1051+
def random() -> float: return 0.5
1052+
1053+
if random():
1054+
for i in range(10):
1055+
try:
1056+
value = random()
1057+
break
1058+
except Exception:
1059+
continue
1060+
else:
1061+
raise RuntimeError
1062+
1063+
print(value) # Should not error - value is defined if we broke
1064+
else:
1065+
value = random()
1066+
1067+
print(value) # Should not error
1068+
[builtins fixtures/for_else_exception.pyi]
1069+
[typing fixtures/typing-medium.pyi]
1070+
1071+
[case testForElseWithBreakInTryExceptContinueNoIf]
1072+
# flags: --enable-error-code possibly-undefined
1073+
# Simpler version without if statement
1074+
def random() -> float: return 0.5
1075+
1076+
for i in range(10):
1077+
try:
1078+
value = random()
1079+
break
1080+
except Exception:
1081+
continue
1082+
else:
1083+
raise RuntimeError
1084+
1085+
print(value) # Should not error
1086+
[builtins fixtures/for_else_exception.pyi]
1087+
[typing fixtures/typing-medium.pyi]
1088+
1089+
[case testForElseWithBreakInTryExceptPass]
1090+
# flags: --enable-error-code possibly-undefined
1091+
# Version with pass instead of continue - should also work
1092+
def random() -> float: return 0.5
1093+
1094+
if random():
1095+
for i in range(10):
1096+
try:
1097+
value = random()
1098+
break
1099+
except Exception:
1100+
pass
1101+
else:
1102+
raise RuntimeError
1103+
1104+
print(value) # Should not error
1105+
else:
1106+
value = random()
1107+
1108+
print(value) # Should not error
1109+
[builtins fixtures/for_else_exception.pyi]
1110+
[typing fixtures/typing-medium.pyi]
1111+
1112+
[case testForElseWithConditionalDefBeforeBreak]
1113+
# flags: --enable-error-code possibly-undefined
1114+
# Test that conditional definition before break still works correctly
1115+
def random() -> float: return 0.5
1116+
1117+
if random():
1118+
for i in range(10):
1119+
try:
1120+
if i > 10:
1121+
value = random()
1122+
break
1123+
except Exception:
1124+
continue
1125+
else:
1126+
raise RuntimeError
1127+
1128+
print(value) # Should not error (though might be undefined at runtime)
1129+
else:
1130+
value = random()
1131+
1132+
print(value) # Should not error
1133+
[builtins fixtures/for_else_exception.pyi]
1134+
[typing fixtures/typing-medium.pyi]
1135+
1136+
[case testForElseDefineInBothBranches]
1137+
# flags: --enable-error-code possibly-undefined
1138+
# Test that a variable defined both before `break` in the loop body and in the `else` block is not reported as possibly undefined
1139+
for i in range(10):
1140+
if i:
1141+
value = i
1142+
break
1143+
else:
1144+
value = 1
1145+
1146+
print(value) # Should not error - value is defined in all paths
1147+
[builtins fixtures/for_else_exception.pyi]
1148+
[typing fixtures/typing-medium.pyi]
1149+
1150+
[case testForElseWithWalrusInBreak]
1151+
# flags: --enable-error-code possibly-undefined
1152+
# Test with walrus operator in if condition before break
1153+
def random() -> float: return 0.5
1154+
1155+
if random():
1156+
for i in range(10):
1157+
if value := random():
1158+
break
1159+
else:
1160+
raise RuntimeError
1161+
1162+
print(value) # Should not error - value is defined if we broke
1163+
else:
1164+
value = random()
1165+
1166+
print(value) # Should not error
1167+
[builtins fixtures/for_else_exception.pyi]
1168+
[typing fixtures/typing-medium.pyi]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Fixture for for-else tests with exceptions
2+
# Combines needed elements from primitives.pyi and exception.pyi
3+
4+
from typing import Generic, Iterator, Mapping, Sequence, TypeVar
5+
6+
T = TypeVar('T')
7+
V = TypeVar('V')
8+
9+
class object:
10+
def __init__(self) -> None: pass
11+
class type:
12+
def __init__(self, x: object) -> None: pass
13+
class int:
14+
def __init__(self, x: object = ..., base: int = ...) -> None: pass
15+
def __add__(self, i: int) -> int: pass
16+
def __rmul__(self, x: int) -> int: pass
17+
def __bool__(self) -> bool: pass
18+
def __eq__(self, x: object) -> bool: pass
19+
def __ne__(self, x: object) -> bool: pass
20+
def __lt__(self, x: 'int') -> bool: pass
21+
def __le__(self, x: 'int') -> bool: pass
22+
def __gt__(self, x: 'int') -> bool: pass
23+
def __ge__(self, x: 'int') -> bool: pass
24+
class float:
25+
def __float__(self) -> float: pass
26+
def __add__(self, x: float) -> float: pass
27+
def hex(self) -> str: pass
28+
class bool(int): pass
29+
class str(Sequence[str]):
30+
def __add__(self, s: str) -> str: pass
31+
def __iter__(self) -> Iterator[str]: pass
32+
def __contains__(self, other: object) -> bool: pass
33+
def __getitem__(self, item: int) -> str: pass
34+
def format(self, *args: object, **kwargs: object) -> str: pass
35+
class dict(Mapping[T, V]):
36+
def __iter__(self) -> Iterator[T]: pass
37+
class tuple(Generic[T]):
38+
def __contains__(self, other: object) -> bool: pass
39+
class ellipsis: pass
40+
41+
class BaseException:
42+
def __init__(self, *args: object) -> None: ...
43+
class Exception(BaseException): pass
44+
class RuntimeError(Exception): pass
45+
46+
class range(Sequence[int]):
47+
def __init__(self, __x: int, __y: int = ..., __z: int = ...) -> None: pass
48+
def count(self, value: int) -> int: pass
49+
def index(self, value: int) -> int: pass
50+
def __getitem__(self, i: int) -> int: pass
51+
def __iter__(self) -> Iterator[int]: pass
52+
def __contains__(self, other: object) -> bool: pass
53+
54+
def print(x: object) -> None: pass

0 commit comments

Comments
 (0)