Skip to content

Commit b3ec408

Browse files
committed
Create fake named expressions for match subject in more cases
1 parent 1affabe commit b3ec408

File tree

2 files changed

+130
-14
lines changed

2 files changed

+130
-14
lines changed

mypy/checker.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,12 @@
6565
CallExpr,
6666
ClassDef,
6767
ComparisonExpr,
68+
ComplexExpr,
6869
Context,
6970
ContinueStmt,
7071
Decorator,
7172
DelStmt,
73+
DictExpr,
7274
EllipsisExpr,
7375
Expression,
7476
ExpressionStmt,
@@ -100,6 +102,7 @@
100102
RaiseStmt,
101103
RefExpr,
102104
ReturnStmt,
105+
SetExpr,
103106
StarExpr,
104107
Statement,
105108
StrExpr,
@@ -350,6 +353,9 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
350353
# functions such as open(), etc.
351354
plugin: Plugin
352355

356+
# A helper state to produce unique temporary names on demand.
357+
_unique_id: int
358+
353359
def __init__(
354360
self,
355361
errors: Errors,
@@ -413,6 +419,7 @@ def __init__(
413419
self, self.msg, self.plugin, per_line_checking_time_ns
414420
)
415421
self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options)
422+
self._unique_id = 0
416423

417424
@property
418425
def type_context(self) -> list[Type | None]:
@@ -5273,19 +5280,7 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
52735280
return
52745281

52755282
def visit_match_stmt(self, s: MatchStmt) -> None:
5276-
named_subject: Expression
5277-
if isinstance(s.subject, CallExpr):
5278-
# Create a dummy subject expression to handle cases where a match statement's subject
5279-
# is not a literal value. This lets us correctly narrow types and check exhaustivity
5280-
# This is hack!
5281-
id = s.subject.callee.fullname if isinstance(s.subject.callee, RefExpr) else ""
5282-
name = "dummy-match-" + id
5283-
v = Var(name)
5284-
named_subject = NameExpr(name)
5285-
named_subject.node = v
5286-
else:
5287-
named_subject = s.subject
5288-
5283+
named_subject = self._make_named_statement_for_match(s.subject)
52895284
with self.binder.frame_context(can_skip=False, fall_through=0):
52905285
subject_type = get_proper_type(self.expr_checker.accept(s.subject))
52915286

@@ -5362,6 +5357,38 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
53625357
with self.binder.frame_context(can_skip=False, fall_through=2):
53635358
pass
53645359

5360+
def _make_named_statement_for_match(self, subject: Expression) -> Expression:
5361+
"""Construct a fake NameExpr for inference if a match clause is complex."""
5362+
expressions_to_preserve = (
5363+
# Already named - we should infer type of it as given
5364+
NameExpr,
5365+
AssignmentExpr,
5366+
# Collection literals defined inline - we want to infer types of variables
5367+
# included there, not exprs as a whole
5368+
ListExpr,
5369+
DictExpr,
5370+
TupleExpr,
5371+
SetExpr,
5372+
# Primitive literals - their type is known, no need to name them
5373+
IntExpr,
5374+
StrExpr,
5375+
BytesExpr,
5376+
FloatExpr,
5377+
ComplexExpr,
5378+
EllipsisExpr,
5379+
)
5380+
if isinstance(subject, expressions_to_preserve):
5381+
return subject
5382+
else:
5383+
# Create a dummy subject expression to handle cases where a match statement's subject
5384+
# is not a literal value. This lets us correctly narrow types and check exhaustivity
5385+
# This is hack!
5386+
name = self.new_unique_dummy_name("match")
5387+
v = Var(name)
5388+
named_subject = NameExpr(name)
5389+
named_subject.node = v
5390+
return named_subject
5391+
53655392
def _get_recursive_sub_patterns_map(
53665393
self, expr: Expression, typ: Type
53675394
) -> dict[Expression, Type]:
@@ -7715,6 +7742,12 @@ def warn_deprecated_overload_item(
77157742
if candidate == target:
77167743
self.warn_deprecated(item.func, context)
77177744

7745+
def new_unique_dummy_name(self, namespace: str) -> str:
7746+
"""Generate a name that is guaranteed to be unique for this TypeChecker instance."""
7747+
name = f"dummy-{namespace}-{self._unique_id}"
7748+
self._unique_id += 1
7749+
return name
7750+
77187751

77197752
class CollectArgTypeVarTypes(TypeTraverserVisitor):
77207753
"""Collects the non-nested argument types in a set."""

test-data/unit/check-python310.test

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,7 @@ def main() -> None:
12391239
case a:
12401240
reveal_type(a) # N: Revealed type is "builtins.int"
12411241

1242-
[case testMatchCapturePatternFromAsyncFunctionReturningUnion-xfail]
1242+
[case testMatchCapturePatternFromAsyncFunctionReturningUnion]
12431243
async def func1(arg: bool) -> str | int: ...
12441244
async def func2(arg: bool) -> bytes | int: ...
12451245

@@ -2439,3 +2439,86 @@ def foo(x: T) -> T:
24392439
return out
24402440

24412441
[builtins fixtures/isinstance.pyi]
2442+
2443+
[case testMatchFunctionCall]
2444+
# flags: --warn-unreachable
2445+
2446+
def fn() -> int | str: ...
2447+
2448+
match fn():
2449+
case str(s):
2450+
reveal_type(s) # N: Revealed type is "builtins.str"
2451+
case int(i):
2452+
reveal_type(i) # N: Revealed type is "builtins.int"
2453+
case other:
2454+
other # E: Statement is unreachable
2455+
2456+
[case testMatchAttribute]
2457+
# flags: --warn-unreachable
2458+
2459+
class A:
2460+
foo: int | str
2461+
2462+
match A().foo:
2463+
case str(s):
2464+
reveal_type(s) # N: Revealed type is "builtins.str"
2465+
case int(i):
2466+
reveal_type(i) # N: Revealed type is "builtins.int"
2467+
case other:
2468+
other # E: Statement is unreachable
2469+
2470+
[case testMatchOperations]
2471+
# flags: --warn-unreachable
2472+
2473+
x: int
2474+
match -x:
2475+
case -1 as s:
2476+
reveal_type(s) # N: Revealed type is "Literal[-1]"
2477+
case int(s):
2478+
reveal_type(s) # N: Revealed type is "builtins.int"
2479+
case other:
2480+
other # E: Statement is unreachable
2481+
2482+
match 1 + 2:
2483+
case 3 as s:
2484+
reveal_type(s) # N: Revealed type is "Literal[3]"
2485+
case int(s):
2486+
reveal_type(s) # N: Revealed type is "builtins.int"
2487+
case other:
2488+
other # E: Statement is unreachable
2489+
2490+
match 1 > 2:
2491+
case True as s:
2492+
reveal_type(s) # N: Revealed type is "Literal[True]"
2493+
case False as s:
2494+
reveal_type(s) # N: Revealed type is "Literal[False]"
2495+
case other:
2496+
other # E: Statement is unreachable
2497+
[builtins fixtures/ops.pyi]
2498+
2499+
[case testMatchDictItem]
2500+
# flags: --warn-unreachable
2501+
2502+
m: dict[str, int | str]
2503+
k: str
2504+
2505+
match m[k]:
2506+
case str(s):
2507+
reveal_type(s) # N: Revealed type is "builtins.str"
2508+
case int(i):
2509+
reveal_type(i) # N: Revealed type is "builtins.int"
2510+
case other:
2511+
other # E: Statement is unreachable
2512+
2513+
[builtins fixtures/dict.pyi]
2514+
2515+
[case testMatchLiteralValuePathological]
2516+
# flags: --warn-unreachable
2517+
2518+
match 0:
2519+
case 0 as i:
2520+
reveal_type(i) # N: Revealed type is "Literal[0]?"
2521+
case int(i):
2522+
i # E: Statement is unreachable
2523+
case other:
2524+
other # E: Statement is unreachable

0 commit comments

Comments
 (0)