33from __future__ import annotations
44
55from collections import defaultdict
6- from itertools import zip_longest
7- from typing import Final , Literal , NamedTuple
6+ from typing import Final , NamedTuple
87
98from mypy import message_registry
109from mypy .checker_shared import TypeCheckerSharedApi , TypeRange
@@ -245,44 +244,35 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
245244
246245 #
247246 # get inner types of original type
248- #
249247 # 1. Go through all possible types and filter to only those which are sequences that could match that number of items
250248 # 2. If there is exactly one tuple left with an unpack, then use that type and the unpack index
251249 # 3. Otherwise, take the product of the item types so that each index can have a unique type. For tuples with unpack
252250 # fallback to merging all of their types for each index since we can't handle multiple unpacked items at once yet.
253251
254- # state of matching
255- # whether one of the possible types is not handled, in which case we want to return an object
256- unknown_value = False
257- state : (
258- # Start in the state where we not encountered an unpack.
259- # a list of all the possible types that could match the sequence. If it's a tuple, then store one for each index
260- tuple [Literal ["NO_UNPACK" ], list [list [Type ]]]
261- |
262- # If we encounter a single tuple with an unpack, store the type, the unpack index, and the index in the union type
263- tuple [Literal ["UNPACK" ], TupleType , int , int ]
264- |
265- # If we have encountered a tuple with an unpack plus any other types, then store a list of them. For any tuples
266- # without unpacks, store them as a list of their items.
267- tuple [Literal ["MULTI_UNPACK" ], list [list [Type ]]]
268- ) = ("NO_UNPACK" , [])
269- for i , t in (
270- enumerate (current_type .items )
271- if isinstance (current_type , UnionType )
272- else ((0 , current_type ),)
252+ # Whether we have encountered a type that we don't know how to handle in the union
253+ unknown_type = False
254+ # A list of types that could match any of the items in the sequence.
255+ sequence_types : list [Type ] = []
256+ # A list of tuple types that could match the sequence, per index
257+ tuple_types : list [list [Type ]] = []
258+ # A list of all the unpack tuple types that we encountered, each containing the tuple type, unpack index, and union index
259+ unpack_tuple_types : list [tuple [TupleType , int , int ]] = []
260+ for i , t in enumerate (
261+ current_type .items if isinstance (current_type , UnionType ) else [current_type ]
273262 ):
274263 t = get_proper_type (t )
275- n_patterns = len (o .patterns )
276264 if isinstance (t , TupleType ):
277- t_items = t .items
265+ t_items = list ( t .items )
278266 unpack_index = find_unpack_in_list (t_items )
279267 if unpack_index is None :
280268 size_diff = len (t_items ) - required_patterns
281269 if size_diff < 0 :
282270 continue
283271 elif size_diff > 0 and star_position is None :
284272 continue
285- inner_t = list (t_items )
273+ elif not size_diff and star_position is not None :
274+ t_items .append (UninhabitedType ()) # add additional item for star if its empty
275+ tuple_types .append (t_items )
286276 else :
287277 normalized_inner_types = []
288278 for it in t_items :
@@ -296,47 +286,35 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
296286 t = t .copy_modified (items = normalized_inner_types )
297287 if len (t_items ) - 1 > required_patterns and star_position is None :
298288 continue
299- if state [0 ] == "NO_UNPACK" and not state [1 ]:
300- state = ("UNPACK" , t , unpack_index , i )
301- continue
302- inner_t = [self .chk .iterable_item_type (tuple_fallback (t ), o )] * n_patterns
289+ unpack_tuple_types .append ((t , unpack_index , i ))
290+ # add the combined tuple type to the sequence types in case we have multiple unpacks we want to combine them all
291+ sequence_types .append (self .chk .iterable_item_type (tuple_fallback (t ), o ))
303292 elif isinstance (t , AnyType ):
304- inner_t = [ AnyType (TypeOfAny .from_another_any , t )] * n_patterns
293+ sequence_types . append ( AnyType (TypeOfAny .from_another_any , t ))
305294 elif self .chk .type_is_iterable (t ) and isinstance (t , Instance ):
306- inner_t = [ self .chk .iterable_item_type (t , o )] * n_patterns
295+ sequence_types . append ( self .chk .iterable_item_type (t , o ))
307296 else :
308- unknown_value = True
309- continue
310- # if we previously encountered an unpack, then change the state.
311- if state [0 ] == "UNPACK" :
312- # if we already unpacked something, change this
313- state = (
314- "MULTI_UNPACK" ,
315- [[self .chk .iterable_item_type (tuple_fallback (state [1 ]), o )] * n_patterns ],
316- )
317- assert state [0 ] != "UNPACK" # for type checker
318- state [1 ].append (inner_t )
319- if state [0 ] == "UNPACK" :
320- _ , update_tuple_type , unpack_index , union_index = state
321- inner_types = update_tuple_type .items
297+ unknown_type = True
298+ # if we only got one unpack tuple type, we can use that
299+ unpack_index = None
300+ if len (unpack_tuple_types ) == 1 and len (sequence_types ) == 1 and not tuple_types :
301+ update_tuple_type , unpack_index , union_index = unpack_tuple_types [0 ]
302+ inner_types : list [Type ] = update_tuple_type .items
322303 if isinstance (current_type , UnionType ):
323304 union_items = list (current_type .items )
324305 union_items [union_index ] = update_tuple_type
325306 current_type = current_type .copy_modified (items = union_items )
326307 else :
327- assert union_index == 0 , "Unpack index should be 0 for non-union types"
328308 current_type = update_tuple_type
309+ # if we only got tuples we can't match, then exit early
310+ elif not tuple_types and not sequence_types and not unknown_type :
311+ return self .early_non_match ()
312+ elif tuple_types :
313+ inner_types = [make_simplified_union ([* sequence_types , * x ]) for x in zip (* tuple_types )]
329314 else :
330- unpack_index = None
331- if state [1 ]:
332- inner_types = [
333- make_simplified_union (x )
334- for x in zip_longest (* state [1 ], fillvalue = UninhabitedType ())
335- ]
336- elif unknown_value :
337- inner_types = [self .chk .named_type ("builtins.object" )] * n_patterns
338- else :
339- return self .early_non_match ()
315+ object_type = self .chk .named_type ("builtins.object" )
316+ inner_type = make_simplified_union (sequence_types ) if sequence_types else object_type
317+ inner_types = [inner_type ] * len (o .patterns )
340318
341319 #
342320 # match inner patterns
0 commit comments