Skip to content

Commit f6c9aa1

Browse files
committed
Track node context
1 parent 82d15f5 commit f6c9aa1

File tree

2 files changed

+33
-17
lines changed

2 files changed

+33
-17
lines changed

mypy/checkpattern.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from mypy.maptype import map_instance_to_supertype
1515
from mypy.meet import narrow_declared_type
1616
from mypy.messages import MessageBuilder
17-
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, Var
17+
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, Node, TypeAlias, Var
1818
from mypy.options import Options
1919
from mypy.patterns import (
2020
AsPattern,
@@ -104,6 +104,8 @@ class PatternChecker(PatternVisitor[PatternType]):
104104
subject_type: Type
105105
# Type of the subject to check the (sub)pattern against
106106
type_context: list[Type]
107+
# Pattern node currently being processed
108+
node_context: list[Node]
107109
# Types that match against self instead of their __match_args__ if used as a class pattern
108110
# Filled in from self_match_type_names
109111
self_match_types: list[Type]
@@ -121,6 +123,7 @@ def __init__(
121123
self.plugin = plugin
122124

123125
self.type_context = []
126+
self.node_context = []
124127
self.self_match_types = self.generate_types_from_names(self_match_type_names)
125128
self.non_sequence_match_types = self.generate_types_from_names(
126129
non_sequence_match_type_names
@@ -129,8 +132,10 @@ def __init__(
129132

130133
def accept(self, o: Pattern, type_context: Type) -> PatternType:
131134
self.type_context.append(type_context)
135+
self.node_context.append(o)
132136
result = o.accept(self)
133137
self.type_context.pop()
138+
self.node_context.pop()
134139

135140
return result
136141

@@ -140,7 +145,12 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
140145
pattern_type = self.accept(o.pattern, current_type)
141146
typ, rest_type, type_map = pattern_type
142147
else:
143-
typ, rest_type, type_map = current_type, UninhabitedType(), {}
148+
typ, type_map = current_type, {}
149+
if len(self.node_context) >= 2 and isinstance(self.node_context[-2], SequencePattern):
150+
# Don't narrow rest type to Never if parent node is a sequence pattern
151+
rest_type = current_type
152+
else:
153+
rest_type = UninhabitedType()
144154

145155
if not is_uninhabited(typ) and o.name is not None:
146156
typ, _ = self.chk.conditional_types_with_intersection(
@@ -315,21 +325,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
315325
)
316326
)
317327
narrowed_inner_types.append(narrowed_inner_type)
318-
narrowed_ptype = get_proper_type(narrowed_inner_type)
319-
if (
320-
is_uninhabited(inner_rest_type)
321-
and isinstance(narrowed_ptype, Instance)
322-
and (
323-
narrowed_ptype.type.fullname == "builtins.dict"
324-
or narrowed_ptype.type.fullname == "builtins.list"
325-
)
326-
):
327-
# Can't narrow rest type to uninhabited
328-
# if narrowed_type is dict or list.
329-
# Those can be matched by Mapping or Sequence patterns.
330-
inner_rest_types.append(narrowed_inner_type)
331-
else:
332-
inner_rest_types.append(inner_rest_type)
328+
inner_rest_types.append(inner_rest_type)
333329
if all(not is_uninhabited(typ) for typ in narrowed_inner_types):
334330
new_type = TupleType(narrowed_inner_types, current_type.partial_fallback)
335331
else:

test-data/unit/check-python310.test

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,13 +1742,33 @@ match (m7, m7):
17421742
case (_, {"a": "2"}):
17431743
reveal_type(m7) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]"
17441744

1745+
match (m7, m7):
1746+
case (dict(), _):
1747+
reveal_type(m7) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]"
1748+
case (_, dict()):
1749+
reveal_type(m7) # E: Statement is unreachable
1750+
case (_, _):
1751+
reveal_type(m7) # E: Statement is unreachable
1752+
1753+
match (m7, 4):
1754+
case ({"a": "1"}, _):
1755+
reveal_type(m7) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]"
1756+
case r7:
1757+
reveal_type(r7) # N: Revealed type is "tuple[builtins.dict[builtins.str, builtins.str], Literal[4]?]"
1758+
17451759
m8: list[int]
17461760

17471761
match (m8, m8):
17481762
case ([1], _):
17491763
reveal_type(m8) # N: Revealed type is "builtins.list[builtins.int]"
17501764
case (_, [2]):
17511765
reveal_type(m8) # N: Revealed type is "builtins.list[builtins.int]"
1766+
1767+
match (m8, m8):
1768+
case (list(), _):
1769+
reveal_type(m8) # N: Revealed type is "builtins.list[builtins.int]"
1770+
case (_, [2]):
1771+
reveal_type(m8) # E: Statement is unreachable
17521772
[builtins fixtures/dict.pyi]
17531773

17541774
[case testMatchEnumSingleChoice]

0 commit comments

Comments
 (0)