1010from mypy .checkmember import analyze_member_access
1111from mypy .expandtype import expand_type_by_instance
1212from mypy .join import join_types
13- from mypy .literals import literal_hash
13+ from mypy .literals import Key , literal_hash
1414from mypy .maptype import map_instance_to_supertype
1515from mypy .meet import narrow_declared_type
1616from mypy .messages import MessageBuilder
17- from mypy .nodes import ARG_POS , Context , Expression , NameExpr , TypeAlias , TypeInfo , Var
17+ from mypy .nodes import (
18+ ARG_POS ,
19+ Context ,
20+ Expression ,
21+ IndexExpr ,
22+ IntExpr ,
23+ ListExpr ,
24+ MemberExpr ,
25+ NameExpr ,
26+ TupleExpr ,
27+ TypeAlias ,
28+ TypeInfo ,
29+ UnaryExpr ,
30+ Var ,
31+ )
1832from mypy .options import Options
1933from mypy .patterns import (
2034 AsPattern ,
@@ -96,10 +110,8 @@ class PatternChecker(PatternVisitor[PatternType]):
96110 msg : MessageBuilder
97111 # Currently unused
98112 plugin : Plugin
99- # The expression being matched against the pattern
100- subject : Expression
101-
102- subject_type : Type
113+ # The expressions being matched against the (sub)pattern
114+ subject_context : list [list [Expression ]]
103115 # Type of the subject to check the (sub)pattern against
104116 type_context : list [Type ]
105117 # Types that match against self instead of their __match_args__ if used as a class pattern
@@ -118,24 +130,28 @@ def __init__(
118130 self .msg = msg
119131 self .plugin = plugin
120132
133+ self .subject_context = []
121134 self .type_context = []
122135 self .self_match_types = self .generate_types_from_names (self_match_type_names )
123136 self .non_sequence_match_types = self .generate_types_from_names (
124137 non_sequence_match_type_names
125138 )
126139 self .options = options
127140
128- def accept (self , o : Pattern , type_context : Type ) -> PatternType :
141+ def accept (self , o : Pattern , type_context : Type , subject : list [Expression ]) -> PatternType :
142+ self .subject_context .append (subject )
129143 self .type_context .append (type_context )
130144 result = o .accept (self )
145+ self .subject_context .pop ()
131146 self .type_context .pop ()
132147
133148 return result
134149
135150 def visit_as_pattern (self , o : AsPattern ) -> PatternType :
151+ current_subject = self .subject_context [- 1 ]
136152 current_type = self .type_context [- 1 ]
137153 if o .pattern is not None :
138- pattern_type = self .accept (o .pattern , current_type )
154+ pattern_type = self .accept (o .pattern , current_type , current_subject )
139155 typ , rest_type , type_map = pattern_type
140156 else :
141157 typ , rest_type , type_map = current_type , UninhabitedType (), {}
@@ -150,14 +166,15 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType:
150166 return PatternType (typ , rest_type , type_map )
151167
152168 def visit_or_pattern (self , o : OrPattern ) -> PatternType :
169+ current_subject = self .subject_context [- 1 ]
153170 current_type = self .type_context [- 1 ]
154171
155172 #
156173 # Check all the subpatterns
157174 #
158- pattern_types = []
175+ pattern_types : list [ PatternType ] = []
159176 for pattern in o .patterns :
160- pattern_type = self .accept (pattern , current_type )
177+ pattern_type = self .accept (pattern , current_type , current_subject )
161178 pattern_types .append (pattern_type )
162179 if not is_uninhabited (pattern_type .type ):
163180 current_type = pattern_type .rest_type
@@ -173,28 +190,42 @@ def visit_or_pattern(self, o: OrPattern) -> PatternType:
173190 #
174191 # Check the capture types
175192 #
176- capture_types : dict [Var , list [tuple [Expression , Type ]]] = defaultdict (list )
193+ capture_types : dict [Var , dict [Key | None , list [tuple [Expression , Type ]]]] = defaultdict (
194+ lambda : defaultdict (list )
195+ )
196+ capture_expr_keys : set [Key | None ] = set ()
177197 # Collect captures from the first subpattern
178198 for expr , typ in pattern_types [0 ].captures .items ():
179- node = get_var (expr )
180- capture_types [node ].append ((expr , typ ))
199+ if (node := get_var (expr )) is None :
200+ continue
201+ key = literal_hash (expr )
202+ capture_types [node ][key ].append ((expr , typ ))
203+ if isinstance (expr , NameExpr ):
204+ capture_expr_keys .add (key )
181205
182206 # Check if other subpatterns capture the same names
183207 for i , pattern_type in enumerate (pattern_types [1 :]):
184- vars = {get_var (expr ) for expr , _ in pattern_type .captures .items ()}
185- if capture_types .keys () != vars :
208+ vars = {
209+ literal_hash (expr ) for expr in pattern_type .captures if isinstance (expr , NameExpr )
210+ }
211+ if capture_expr_keys != vars :
212+ # Only fail for directly captured names (with NameExpr)
186213 self .msg .fail (message_registry .OR_PATTERN_ALTERNATIVE_NAMES , o .patterns [i ])
187214 for expr , typ in pattern_type .captures .items ():
188- node = get_var (expr )
189- capture_types [node ].append ((expr , typ ))
215+ if (node := get_var (expr )) is None :
216+ continue
217+ key = literal_hash (expr )
218+ capture_types [node ][key ].append ((expr , typ ))
190219
191220 captures : dict [Expression , Type ] = {}
192- for capture_list in capture_types .values ():
193- typ = UninhabitedType ()
194- for _ , other in capture_list :
195- typ = make_simplified_union ([typ , other ])
221+ for expressions in capture_types .values ():
222+ for key , capture_list in expressions .items ():
223+ if other_types := [entry [1 ] for entry in capture_list ]:
224+ typ = make_simplified_union (other_types )
225+ else :
226+ typ = UninhabitedType ()
196227
197- captures [capture_list [0 ][0 ]] = typ
228+ captures [capture_list [0 ][0 ]] = typ
198229
199230 union_type = make_simplified_union (types )
200231 return PatternType (union_type , current_type , captures )
@@ -284,12 +315,37 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
284315 contracted_inner_types = self .contract_starred_pattern_types (
285316 inner_types , star_position , required_patterns
286317 )
287- for p , t in zip (o .patterns , contracted_inner_types ):
288- pattern_type = self .accept (p , t )
318+ current_subjects : list [list [Expression ]] = [[] for _ in range (len (contracted_inner_types ))]
319+ end_pos = len (contracted_inner_types ) if star_position is None else star_position
320+ for subject in self .subject_context [- 1 ]:
321+ if isinstance (subject , (ListExpr , TupleExpr )):
322+ # For list and tuple expressions, lookup expression in items
323+ for i in range (end_pos ):
324+ if i < len (subject .items ):
325+ current_subjects [i ].append (subject .items [i ])
326+ if star_position is not None :
327+ for i in range (star_position + 1 , len (contracted_inner_types )):
328+ offset = len (contracted_inner_types ) - i
329+ if offset <= len (subject .items ):
330+ current_subjects [i ].append (subject .items [- offset ])
331+ else :
332+ # Support x[0], x[1], ... lookup until wildcard
333+ for i in range (end_pos ):
334+ current_subjects [i ].append (IndexExpr (subject , IntExpr (i )))
335+ # For everything after wildcard use x[-2], x[-1]
336+ for i in range ((star_position or - 1 ) + 1 , len (contracted_inner_types )):
337+ offset = len (contracted_inner_types ) - i
338+ current_subjects [i ].append (IndexExpr (subject , UnaryExpr ("-" , IntExpr (offset ))))
339+ for p , t , s in zip (o .patterns , contracted_inner_types , current_subjects ):
340+ pattern_type = self .accept (p , t , s )
289341 typ , rest , type_map = pattern_type
290342 contracted_new_inner_types .append (typ )
291343 contracted_rest_inner_types .append (rest )
292344 self .update_type_map (captures , type_map )
345+ if s :
346+ self .update_type_map (
347+ captures , {subject : typ for subject in s }, fail_multiple_assignments = False
348+ )
293349
294350 new_inner_types = self .expand_starred_pattern_types (
295351 contracted_new_inner_types , star_position , len (inner_types ), unpack_index is not None
@@ -473,11 +529,18 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
473529 if inner_type is None :
474530 can_match = False
475531 inner_type = self .chk .named_type ("builtins.object" )
476- pattern_type = self .accept (value , inner_type )
532+ current_subjects : list [Expression ] = [
533+ IndexExpr (s , key ) for s in self .subject_context [- 1 ]
534+ ]
535+ pattern_type = self .accept (value , inner_type , current_subjects )
477536 if is_uninhabited (pattern_type .type ):
478537 can_match = False
479538 else :
480539 self .update_type_map (captures , pattern_type .captures )
540+ if current_subjects :
541+ self .update_type_map (
542+ captures , {subject : pattern_type .type for subject in current_subjects }
543+ )
481544
482545 if o .rest is not None :
483546 mapping = self .chk .named_type ("typing.Mapping" )
@@ -581,7 +644,7 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
581644 if self .should_self_match (typ ):
582645 if len (o .positionals ) > 1 :
583646 self .msg .fail (message_registry .CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS , o )
584- pattern_type = self .accept (o .positionals [0 ], narrowed_type )
647+ pattern_type = self .accept (o .positionals [0 ], narrowed_type , [] )
585648 if not is_uninhabited (pattern_type .type ):
586649 return PatternType (
587650 pattern_type .type ,
@@ -681,11 +744,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
681744 elif keyword is not None :
682745 new_type = self .chk .add_any_attribute_to_type (new_type , keyword )
683746
684- inner_type , inner_rest_type , inner_captures = self .accept (pattern , key_type )
747+ current_subjects : list [Expression ] = []
748+ if keyword is not None :
749+ current_subjects = [MemberExpr (s , keyword ) for s in self .subject_context [- 1 ]]
750+ inner_type , inner_rest_type , inner_captures = self .accept (
751+ pattern , key_type , current_subjects
752+ )
685753 if is_uninhabited (inner_type ):
686754 can_match = False
687755 else :
688756 self .update_type_map (captures , inner_captures )
757+ if current_subjects :
758+ self .update_type_map (
759+ captures , {subject : inner_type for subject in current_subjects }
760+ )
689761 if not is_uninhabited (inner_rest_type ):
690762 rest_type = current_type
691763
@@ -732,17 +804,22 @@ def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
732804 return types
733805
734806 def update_type_map (
735- self , original_type_map : dict [Expression , Type ], extra_type_map : dict [Expression , Type ]
807+ self ,
808+ original_type_map : dict [Expression , Type ],
809+ extra_type_map : dict [Expression , Type ],
810+ fail_multiple_assignments : bool = True ,
736811 ) -> None :
737812 # Calculating this would not be needed if TypeMap directly used literal hashes instead of
738813 # expressions, as suggested in the TODO above it's definition
739814 already_captured = {literal_hash (expr ) for expr in original_type_map }
740815 for expr , typ in extra_type_map .items ():
741816 if literal_hash (expr ) in already_captured :
742- node = get_var (expr )
743- self .msg .fail (
744- message_registry .MULTIPLE_ASSIGNMENTS_IN_PATTERN .format (node .name ), expr
745- )
817+ if (node := get_var (expr )) is None :
818+ continue
819+ if fail_multiple_assignments :
820+ self .msg .fail (
821+ message_registry .MULTIPLE_ASSIGNMENTS_IN_PATTERN .format (node .name ), expr
822+ )
746823 else :
747824 original_type_map [expr ] = typ
748825
@@ -794,12 +871,17 @@ def get_match_arg_names(typ: TupleType) -> list[str | None]:
794871 return args
795872
796873
797- def get_var (expr : Expression ) -> Var :
874+ def get_var (expr : Expression ) -> Var | None :
798875 """
799876 Warning: this in only true for expressions captured by a match statement.
800877 Don't call it from anywhere else
801878 """
802- assert isinstance (expr , NameExpr ), expr
879+ if isinstance (expr , MemberExpr ):
880+ return get_var (expr .expr )
881+ if isinstance (expr , IndexExpr ):
882+ return get_var (expr .base )
883+ if not isinstance (expr , NameExpr ):
884+ return None
803885 node = expr .node
804886 assert isinstance (node , Var ), node
805887 return node
0 commit comments