Skip to content

Commit c51ec9b

Browse files
committed
Improve match inline list subject inference
1 parent e852829 commit c51ec9b

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

mypy/fastparse.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ def __init__(
396396
# 'C' for class, 'D' for function signature, 'F' for function, 'L' for lambda
397397
self.class_and_function_stack: list[Literal["C", "D", "F", "L"]] = []
398398
self.imports: list[ImportBase] = []
399+
self.match_stmt_subject = False
399400

400401
self.options = options
401402
self.is_stub = is_stub
@@ -1758,7 +1759,7 @@ def visit_Name(self, n: Name) -> NameExpr:
17581759
# List(expr* elts, expr_context ctx)
17591760
def visit_List(self, n: ast3.List) -> ListExpr | TupleExpr:
17601761
expr_list: list[Expression] = [self.visit(e) for e in n.elts]
1761-
if isinstance(n.ctx, ast3.Store):
1762+
if isinstance(n.ctx, ast3.Store) or self.match_stmt_subject:
17621763
# [x, y] = z and (x, y) = z means exactly the same thing
17631764
e: ListExpr | TupleExpr = TupleExpr(expr_list)
17641765
else:
@@ -1779,8 +1780,11 @@ def visit_Slice(self, n: ast3.Slice) -> SliceExpr:
17791780

17801781
# Match(expr subject, match_case* cases) # python 3.10 and later
17811782
def visit_Match(self, n: Match) -> MatchStmt:
1783+
self.match_stmt_subject = True
1784+
subject = self.visit(n.subject)
1785+
self.match_stmt_subject = False
17821786
node = MatchStmt(
1783-
self.visit(n.subject),
1787+
subject,
17841788
[self.visit(c.pattern) for c in n.cases],
17851789
[self.visit(c.guard) for c in n.cases],
17861790
[self.as_required_block(c.body) for c in n.cases],

test-data/unit/check-python310.test

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ match x:
318318
pass
319319

320320
[case testMatchSequencePatternWithInvalidClassPattern]
321+
# flags: --warn-unreachable
321322
class Example:
322323
__match_args__ = ("value",)
323324
def __init__(self, value: str) -> None:
@@ -327,11 +328,32 @@ SubClass: type[Example]
327328

328329
match [SubClass("a"), SubClass("b")]:
329330
case [SubClass(value), *rest]: # E: Expected type in class pattern; found "type[__main__.Example]"
330-
reveal_type(value) # E: Cannot determine type of "value" \
331-
# N: Revealed type is "Any"
331+
reveal_type(value) # E: Statement is unreachable
332+
reveal_type(rest)
333+
case [Example(value), *rest]:
334+
reveal_type(value) # N: Revealed type is "builtins.str"
332335
reveal_type(rest) # N: Revealed type is "builtins.list[__main__.Example]"
333336
[builtins fixtures/tuple.pyi]
334337

338+
[case testMatchSequencePatternSequenceSubject]
339+
a: int
340+
b: str
341+
match a, b:
342+
case 1, "Hello":
343+
reveal_type(a) # N: Revealed type is "Literal[1]"
344+
reveal_type(b) # N: Revealed type is "Literal['Hello']"
345+
346+
match (a, b):
347+
case (1, "Hello"):
348+
reveal_type(a) # N: Revealed type is "Literal[1]"
349+
reveal_type(b) # N: Revealed type is "Literal['Hello']"
350+
351+
match [a, b]:
352+
case [1, "Hello"]:
353+
reveal_type(a) # N: Revealed type is "Literal[1]"
354+
reveal_type(b) # N: Revealed type is "Literal['Hello']"
355+
[builtins fixtures/tuple.pyi]
356+
335357
# Narrowing union-based values via a literal pattern on an indexed/attribute subject
336358
# -------------------------------------------------------------------------------
337359
# Literal patterns against a union of types can be used to narrow the subject

0 commit comments

Comments
 (0)