Skip to content

Commit 8844183

Browse files
committed
Merge branch 'master' into narrowing/refine_partial_types_in_loops
# Conflicts: # test-data/unit/check-narrowing.test
2 parents c8eee51 + 499adae commit 8844183

12 files changed

+273
-48
lines changed

mypy/binder.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections import defaultdict
44
from contextlib import contextmanager
5-
from typing import DefaultDict, Iterator, List, Optional, Tuple, Union, cast
5+
from typing import DefaultDict, Iterator, List, NamedTuple, Optional, Tuple, Union
66
from typing_extensions import TypeAlias as _TypeAlias
77

88
from mypy.erasetype import remove_instance_last_known_values
@@ -30,6 +30,11 @@
3030
BindableExpression: _TypeAlias = Union[IndexExpr, MemberExpr, NameExpr]
3131

3232

33+
class CurrentType(NamedTuple):
34+
type: Type
35+
from_assignment: bool
36+
37+
3338
class Frame:
3439
"""A Frame represents a specific point in the execution of a program.
3540
It carries information about the current types of expressions at
@@ -44,7 +49,7 @@ class Frame:
4449

4550
def __init__(self, id: int, conditional_frame: bool = False) -> None:
4651
self.id = id
47-
self.types: dict[Key, Type] = {}
52+
self.types: dict[Key, CurrentType] = {}
4853
self.unreachable = False
4954
self.conditional_frame = conditional_frame
5055
self.suppress_unreachable_warnings = False
@@ -132,18 +137,18 @@ def push_frame(self, conditional_frame: bool = False) -> Frame:
132137
self.options_on_return.append([])
133138
return f
134139

135-
def _put(self, key: Key, type: Type, index: int = -1) -> None:
136-
self.frames[index].types[key] = type
140+
def _put(self, key: Key, type: Type, from_assignment: bool, index: int = -1) -> None:
141+
self.frames[index].types[key] = CurrentType(type, from_assignment)
137142

138-
def _get(self, key: Key, index: int = -1) -> Type | None:
143+
def _get(self, key: Key, index: int = -1) -> CurrentType | None:
139144
if index < 0:
140145
index += len(self.frames)
141146
for i in range(index, -1, -1):
142147
if key in self.frames[i].types:
143148
return self.frames[i].types[key]
144149
return None
145150

146-
def put(self, expr: Expression, typ: Type) -> None:
151+
def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> None:
147152
if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
148153
return
149154
if not literal(expr):
@@ -153,7 +158,7 @@ def put(self, expr: Expression, typ: Type) -> None:
153158
if key not in self.declarations:
154159
self.declarations[key] = get_declaration(expr)
155160
self._add_dependencies(key)
156-
self._put(key, typ)
161+
self._put(key, typ, from_assignment)
157162

158163
def unreachable(self) -> None:
159164
self.frames[-1].unreachable = True
@@ -164,7 +169,10 @@ def suppress_unreachable_warnings(self) -> None:
164169
def get(self, expr: Expression) -> Type | None:
165170
key = literal_hash(expr)
166171
assert key is not None, "Internal error: binder tried to get non-literal"
167-
return self._get(key)
172+
found = self._get(key)
173+
if found is None:
174+
return None
175+
return found.type
168176

169177
def is_unreachable(self) -> bool:
170178
# TODO: Copy the value of unreachable into new frames to avoid
@@ -193,7 +201,7 @@ def update_from_options(self, frames: list[Frame]) -> bool:
193201
If a key is declared as AnyType, only update it if all the
194202
options are the same.
195203
"""
196-
204+
all_reachable = all(not f.unreachable for f in frames)
197205
frames = [f for f in frames if not f.unreachable]
198206
changed = False
199207
keys = {key for f in frames for key in f.types}
@@ -207,17 +215,30 @@ def update_from_options(self, frames: list[Frame]) -> bool:
207215
# know anything about key in at least one possible frame.
208216
continue
209217

210-
type = resulting_values[0]
211-
assert type is not None
218+
if all_reachable and all(
219+
x is not None and not x.from_assignment for x in resulting_values
220+
):
221+
# Do not synthesize a new type if we encountered a conditional block
222+
# (if, while or match-case) without assignments.
223+
# See check-isinstance.test::testNoneCheckDoesNotMakeTypeVarOptional
224+
# This is a safe assumption: the fact that we checked something with `is`
225+
# or `isinstance` does not change the type of the value.
226+
continue
227+
228+
current_type = resulting_values[0]
229+
assert current_type is not None
230+
type = current_type.type
212231
declaration_type = get_proper_type(self.declarations.get(key))
213232
if isinstance(declaration_type, AnyType):
214233
# At this point resulting values can't contain None, see continue above
215-
if not all(is_same_type(type, cast(Type, t)) for t in resulting_values[1:]):
234+
if not all(
235+
t is not None and is_same_type(type, t.type) for t in resulting_values[1:]
236+
):
216237
type = AnyType(TypeOfAny.from_another_any, source_any=declaration_type)
217238
else:
218239
for other in resulting_values[1:]:
219240
assert other is not None
220-
type = join_simple(self.declarations[key], type, other)
241+
type = join_simple(self.declarations[key], type, other.type)
221242
# Try simplifying resulting type for unions involving variadic tuples.
222243
# Technically, everything is still valid without this step, but if we do
223244
# not do this, this may create long unions after exiting an if check like:
@@ -236,8 +257,8 @@ def update_from_options(self, frames: list[Frame]) -> bool:
236257
)
237258
if simplified == self.declarations[key]:
238259
type = simplified
239-
if current_value is None or not is_same_type(type, current_value):
240-
self._put(key, type)
260+
if current_value is None or not is_same_type(type, current_value[0]):
261+
self._put(key, type, from_assignment=True)
241262
changed = True
242263

243264
self.frames[-1].unreachable = not frames
@@ -374,7 +395,9 @@ def most_recent_enclosing_type(self, expr: BindableExpression, type: Type) -> Ty
374395
key = literal_hash(expr)
375396
assert key is not None
376397
enclosers = [get_declaration(expr)] + [
377-
f.types[key] for f in self.frames if key in f.types and is_subtype(type, f.types[key])
398+
f.types[key].type
399+
for f in self.frames
400+
if key in f.types and is_subtype(type, f.types[key][0])
378401
]
379402
return enclosers[-1]
380403

mypy/checker.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4728,11 +4728,11 @@ def visit_if_stmt(self, s: IfStmt) -> None:
47284728

47294729
# XXX Issue a warning if condition is always False?
47304730
with self.binder.frame_context(can_skip=True, fall_through=2):
4731-
self.push_type_map(if_map)
4731+
self.push_type_map(if_map, from_assignment=False)
47324732
self.accept(b)
47334733

47344734
# XXX Issue a warning if condition is always True?
4735-
self.push_type_map(else_map)
4735+
self.push_type_map(else_map, from_assignment=False)
47364736

47374737
with self.binder.frame_context(can_skip=False, fall_through=2):
47384738
if s.else_body:
@@ -5313,18 +5313,21 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
53135313
if b.is_unreachable or isinstance(
53145314
get_proper_type(pattern_type.type), UninhabitedType
53155315
):
5316-
self.push_type_map(None)
5316+
self.push_type_map(None, from_assignment=False)
53175317
else_map: TypeMap = {}
53185318
else:
53195319
pattern_map, else_map = conditional_types_to_typemaps(
53205320
named_subject, pattern_type.type, pattern_type.rest_type
53215321
)
53225322
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
5323-
self.push_type_map(pattern_map)
5323+
self.push_type_map(pattern_map, from_assignment=False)
53245324
if pattern_map:
53255325
for expr, typ in pattern_map.items():
5326-
self.push_type_map(self._get_recursive_sub_patterns_map(expr, typ))
5327-
self.push_type_map(pattern_type.captures)
5326+
self.push_type_map(
5327+
self._get_recursive_sub_patterns_map(expr, typ),
5328+
from_assignment=False,
5329+
)
5330+
self.push_type_map(pattern_type.captures, from_assignment=False)
53285331
if g is not None:
53295332
with self.binder.frame_context(can_skip=False, fall_through=3):
53305333
gt = get_proper_type(self.expr_checker.accept(g))
@@ -5350,11 +5353,11 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
53505353
continue
53515354
type_map[named_subject] = type_map[expr]
53525355

5353-
self.push_type_map(guard_map)
5356+
self.push_type_map(guard_map, from_assignment=False)
53545357
self.accept(b)
53555358
else:
53565359
self.accept(b)
5357-
self.push_type_map(else_map)
5360+
self.push_type_map(else_map, from_assignment=False)
53585361

53595362
# This is needed due to a quirk in frame_context. Without it types will stay narrowed
53605363
# after the match.
@@ -7375,12 +7378,12 @@ def iterable_item_type(
73757378
def function_type(self, func: FuncBase) -> FunctionLike:
73767379
return function_type(func, self.named_type("builtins.function"))
73777380

7378-
def push_type_map(self, type_map: TypeMap) -> None:
7381+
def push_type_map(self, type_map: TypeMap, *, from_assignment: bool = True) -> None:
73797382
if type_map is None:
73807383
self.binder.unreachable()
73817384
else:
73827385
for expr, type in type_map.items():
7383-
self.binder.put(expr, type)
7386+
self.binder.put(expr, type, from_assignment=from_assignment)
73847387

73857388
def infer_issubclass_maps(self, node: CallExpr, expr: Expression) -> tuple[TypeMap, TypeMap]:
73867389
"""Infer type restrictions for an expression in issubclass call."""
@@ -7753,9 +7756,7 @@ def conditional_types(
77537756
) and is_proper_subtype(current_type, proposed_type, ignore_promotions=True):
77547757
# Expression is always of one of the types in proposed_type_ranges
77557758
return default, UninhabitedType()
7756-
elif not is_overlapping_types(
7757-
current_type, proposed_type, prohibit_none_typevar_overlap=True, ignore_promotions=True
7758-
):
7759+
elif not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
77597760
# Expression is never of any type in proposed_type_ranges
77607761
return UninhabitedType(), default
77617762
else:

mypy/semanal.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,12 @@ def __init__(
484484
# Used to pass information about current overload index to visit_func_def().
485485
self.current_overload_item: int | None = None
486486

487+
# Used to track whether currently inside an except* block. This helps
488+
# to invoke errors when continue/break/return is used inside except* block.
489+
self.inside_except_star_block: bool = False
490+
# Used to track edge case when return is still inside except* if it enters a loop
491+
self.return_stmt_inside_except_star_block: bool = False
492+
487493
# mypyc doesn't properly handle implementing an abstractproperty
488494
# with a regular attribute so we make them properties
489495
@property
@@ -511,6 +517,25 @@ def allow_unbound_tvars_set(self) -> Iterator[None]:
511517
finally:
512518
self.allow_unbound_tvars = old
513519

520+
@contextmanager
521+
def inside_except_star_block_set(
522+
self, value: bool, entering_loop: bool = False
523+
) -> Iterator[None]:
524+
old = self.inside_except_star_block
525+
self.inside_except_star_block = value
526+
527+
# Return statement would still be in except* scope if entering loops
528+
if not entering_loop:
529+
old_return_stmt_flag = self.return_stmt_inside_except_star_block
530+
self.return_stmt_inside_except_star_block = value
531+
532+
try:
533+
yield
534+
finally:
535+
self.inside_except_star_block = old
536+
if not entering_loop:
537+
self.return_stmt_inside_except_star_block = old_return_stmt_flag
538+
514539
#
515540
# Preparing module (performed before semantic analysis)
516541
#
@@ -877,7 +902,8 @@ def visit_func_def(self, defn: FuncDef) -> None:
877902
return
878903

879904
with self.scope.function_scope(defn):
880-
self.analyze_func_def(defn)
905+
with self.inside_except_star_block_set(value=False):
906+
self.analyze_func_def(defn)
881907

882908
def function_fullname(self, fullname: str) -> str:
883909
if self.current_overload_item is None:
@@ -1684,6 +1710,7 @@ def visit_decorator(self, dec: Decorator) -> None:
16841710
"abc.abstractproperty",
16851711
"functools.cached_property",
16861712
"enum.property",
1713+
"types.DynamicClassAttribute",
16871714
),
16881715
):
16891716
removed.append(i)
@@ -5263,6 +5290,8 @@ def visit_return_stmt(self, s: ReturnStmt) -> None:
52635290
self.statement = s
52645291
if not self.is_func_scope():
52655292
self.fail('"return" outside function', s)
5293+
if self.return_stmt_inside_except_star_block:
5294+
self.fail('"return" not allowed in except* block', s, serious=True)
52665295
if s.expr:
52675296
s.expr.accept(self)
52685297

@@ -5296,7 +5325,8 @@ def visit_while_stmt(self, s: WhileStmt) -> None:
52965325
self.statement = s
52975326
s.expr.accept(self)
52985327
self.loop_depth[-1] += 1
5299-
s.body.accept(self)
5328+
with self.inside_except_star_block_set(value=False, entering_loop=True):
5329+
s.body.accept(self)
53005330
self.loop_depth[-1] -= 1
53015331
self.visit_block_maybe(s.else_body)
53025332

@@ -5320,20 +5350,24 @@ def visit_for_stmt(self, s: ForStmt) -> None:
53205350
s.index_type = analyzed
53215351

53225352
self.loop_depth[-1] += 1
5323-
self.visit_block(s.body)
5353+
with self.inside_except_star_block_set(value=False, entering_loop=True):
5354+
self.visit_block(s.body)
53245355
self.loop_depth[-1] -= 1
5325-
53265356
self.visit_block_maybe(s.else_body)
53275357

53285358
def visit_break_stmt(self, s: BreakStmt) -> None:
53295359
self.statement = s
53305360
if self.loop_depth[-1] == 0:
53315361
self.fail('"break" outside loop', s, serious=True, blocker=True)
5362+
if self.inside_except_star_block:
5363+
self.fail('"break" not allowed in except* block', s, serious=True)
53325364

53335365
def visit_continue_stmt(self, s: ContinueStmt) -> None:
53345366
self.statement = s
53355367
if self.loop_depth[-1] == 0:
53365368
self.fail('"continue" outside loop', s, serious=True, blocker=True)
5369+
if self.inside_except_star_block:
5370+
self.fail('"continue" not allowed in except* block', s, serious=True)
53375371

53385372
def visit_if_stmt(self, s: IfStmt) -> None:
53395373
self.statement = s
@@ -5354,7 +5388,8 @@ def analyze_try_stmt(self, s: TryStmt, visitor: NodeVisitor[None]) -> None:
53545388
type.accept(visitor)
53555389
if var:
53565390
self.analyze_lvalue(var)
5357-
handler.accept(visitor)
5391+
with self.inside_except_star_block_set(self.inside_except_star_block or s.is_star):
5392+
handler.accept(visitor)
53585393
if s.else_body:
53595394
s.else_body.accept(visitor)
53605395
if s.finally_body:

mypy/typeanal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def analyze_type_alias(
180180
)
181181
analyzer.in_dynamic_func = in_dynamic_func
182182
analyzer.global_scope = global_scope
183-
res = type.accept(analyzer)
183+
res = analyzer.anal_type(type, nested=False)
184184
return res, analyzer.aliases_used
185185

186186

test-data/unit/check-enum.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,7 @@ elif x is Foo.C:
815815
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.C]"
816816
else:
817817
reveal_type(x) # No output here: this branch is unreachable
818-
reveal_type(x) # N: Revealed type is "__main__.Foo"
818+
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[__main__.Foo.C]]"
819819

820820
if Foo.A is x:
821821
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.A]"
@@ -825,7 +825,7 @@ elif Foo.C is x:
825825
reveal_type(x) # N: Revealed type is "Literal[__main__.Foo.C]"
826826
else:
827827
reveal_type(x) # No output here: this branch is unreachable
828-
reveal_type(x) # N: Revealed type is "__main__.Foo"
828+
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[__main__.Foo.C]]"
829829

830830
y: Foo
831831
if y is Foo.A:

test-data/unit/check-isinstance.test

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2207,23 +2207,24 @@ def foo2(x: Optional[str]) -> None:
22072207
reveal_type(x) # N: Revealed type is "builtins.str"
22082208
[builtins fixtures/isinstance.pyi]
22092209

2210-
[case testNoneCheckDoesNotNarrowWhenUsingTypeVars]
2211-
2212-
# Note: this test (and the following one) are testing checker.conditional_type_map:
2213-
# if you set the 'prohibit_none_typevar_overlap' keyword argument to False when calling
2214-
# 'is_overlapping_types', the binder will incorrectly infer that 'out' has a type of
2215-
# Union[T, None] after the if statement.
2216-
2210+
[case testNoneCheckDoesNotMakeTypeVarOptional]
22172211
from typing import TypeVar
22182212

22192213
T = TypeVar('T')
22202214

2221-
def foo(x: T) -> T:
2215+
def foo_if(x: T) -> T:
22222216
out = None
22232217
out = x
22242218
if out is None:
22252219
pass
22262220
return out
2221+
2222+
def foo_while(x: T) -> T:
2223+
out = None
2224+
out = x
2225+
while out is None:
2226+
pass
2227+
return out
22272228
[builtins fixtures/isinstance.pyi]
22282229

22292230
[case testNoneCheckDoesNotNarrowWhenUsingTypeVarsNoStrictOptional]

0 commit comments

Comments
 (0)