Skip to content

Commit c9d1233

Browse files
Simplify logic
1 parent 543d224 commit c9d1233

File tree

2 files changed

+36
-58
lines changed

2 files changed

+36
-58
lines changed

mypy/checkpattern.py

Lines changed: 34 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from __future__ import annotations
44

55
from collections import defaultdict
6-
from itertools import zip_longest
7-
from typing import Final, Literal, NamedTuple
6+
from typing import Final, NamedTuple
87

98
from mypy import message_registry
109
from mypy.checker_shared import TypeCheckerSharedApi, TypeRange
@@ -245,44 +244,35 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
245244

246245
#
247246
# get inner types of original type
248-
#
249247
# 1. Go through all possible types and filter to only those which are sequences that could match that number of items
250248
# 2. If there is exactly one tuple left with an unpack, then use that type and the unpack index
251249
# 3. Otherwise, take the product of the item types so that each index can have a unique type. For tuples with unpack
252250
# fallback to merging all of their types for each index since we can't handle multiple unpacked items at once yet.
253251

254-
# state of matching
255-
# whether one of the possible types is not handled, in which case we want to return an object
256-
unknown_value = False
257-
state: (
258-
# Start in the state where we not encountered an unpack.
259-
# a list of all the possible types that could match the sequence. If it's a tuple, then store one for each index
260-
tuple[Literal["NO_UNPACK"], list[list[Type]]]
261-
|
262-
# If we encounter a single tuple with an unpack, store the type, the unpack index, and the index in the union type
263-
tuple[Literal["UNPACK"], TupleType, int, int]
264-
|
265-
# If we have encountered a tuple with an unpack plus any other types, then store a list of them. For any tuples
266-
# without unpacks, store them as a list of their items.
267-
tuple[Literal["MULTI_UNPACK"], list[list[Type]]]
268-
) = ("NO_UNPACK", [])
269-
for i, t in (
270-
enumerate(current_type.items)
271-
if isinstance(current_type, UnionType)
272-
else ((0, current_type),)
252+
# Whether we have encountered a type that we don't know how to handle in the union
253+
unknown_type = False
254+
# A list of types that could match any of the items in the sequence.
255+
sequence_types: list[Type] = []
256+
# A list of tuple types that could match the sequence, per index
257+
tuple_types: list[list[Type]] = []
258+
# A list of all the unpack tuple types that we encountered, each containing the tuple type, unpack index, and union index
259+
unpack_tuple_types: list[tuple[TupleType, int, int]] = []
260+
for i, t in enumerate(
261+
current_type.items if isinstance(current_type, UnionType) else [current_type]
273262
):
274263
t = get_proper_type(t)
275-
n_patterns = len(o.patterns)
276264
if isinstance(t, TupleType):
277-
t_items = t.items
265+
t_items = list(t.items)
278266
unpack_index = find_unpack_in_list(t_items)
279267
if unpack_index is None:
280268
size_diff = len(t_items) - required_patterns
281269
if size_diff < 0:
282270
continue
283271
elif size_diff > 0 and star_position is None:
284272
continue
285-
inner_t = list(t_items)
273+
elif not size_diff and star_position is not None:
274+
t_items.append(UninhabitedType()) # add additional item for star if its empty
275+
tuple_types.append(t_items)
286276
else:
287277
normalized_inner_types = []
288278
for it in t_items:
@@ -296,47 +286,35 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
296286
t = t.copy_modified(items=normalized_inner_types)
297287
if len(t_items) - 1 > required_patterns and star_position is None:
298288
continue
299-
if state[0] == "NO_UNPACK" and not state[1]:
300-
state = ("UNPACK", t, unpack_index, i)
301-
continue
302-
inner_t = [self.chk.iterable_item_type(tuple_fallback(t), o)] * n_patterns
289+
unpack_tuple_types.append((t, unpack_index, i))
290+
# add the combined tuple type to the sequence types in case we have multiple unpacks we want to combine them all
291+
sequence_types.append(self.chk.iterable_item_type(tuple_fallback(t), o))
303292
elif isinstance(t, AnyType):
304-
inner_t = [AnyType(TypeOfAny.from_another_any, t)] * n_patterns
293+
sequence_types.append(AnyType(TypeOfAny.from_another_any, t))
305294
elif self.chk.type_is_iterable(t) and isinstance(t, Instance):
306-
inner_t = [self.chk.iterable_item_type(t, o)] * n_patterns
295+
sequence_types.append(self.chk.iterable_item_type(t, o))
307296
else:
308-
unknown_value = True
309-
continue
310-
# if we previously encountered an unpack, then change the state.
311-
if state[0] == "UNPACK":
312-
# if we already unpacked something, change this
313-
state = (
314-
"MULTI_UNPACK",
315-
[[self.chk.iterable_item_type(tuple_fallback(state[1]), o)] * n_patterns],
316-
)
317-
assert state[0] != "UNPACK" # for type checker
318-
state[1].append(inner_t)
319-
if state[0] == "UNPACK":
320-
_, update_tuple_type, unpack_index, union_index = state
321-
inner_types = update_tuple_type.items
297+
unknown_type = True
298+
# if we only got one unpack tuple type, we can use that
299+
unpack_index = None
300+
if len(unpack_tuple_types) == 1 and len(sequence_types) == 1 and not tuple_types:
301+
update_tuple_type, unpack_index, union_index = unpack_tuple_types[0]
302+
inner_types: list[Type] = update_tuple_type.items
322303
if isinstance(current_type, UnionType):
323304
union_items = list(current_type.items)
324305
union_items[union_index] = update_tuple_type
325306
current_type = current_type.copy_modified(items=union_items)
326307
else:
327-
assert union_index == 0, "Unpack index should be 0 for non-union types"
328308
current_type = update_tuple_type
309+
# if we only got tuples we can't match, then exit early
310+
elif not tuple_types and not sequence_types and not unknown_type:
311+
return self.early_non_match()
312+
elif tuple_types:
313+
inner_types = [make_simplified_union([*sequence_types, *x]) for x in zip(*tuple_types)]
329314
else:
330-
unpack_index = None
331-
if state[1]:
332-
inner_types = [
333-
make_simplified_union(x)
334-
for x in zip_longest(*state[1], fillvalue=UninhabitedType())
335-
]
336-
elif unknown_value:
337-
inner_types = [self.chk.named_type("builtins.object")] * n_patterns
338-
else:
339-
return self.early_non_match()
315+
object_type = self.chk.named_type("builtins.object")
316+
inner_type = make_simplified_union(sequence_types) if sequence_types else object_type
317+
inner_types = [inner_type] * len(o.patterns)
340318

341319
#
342320
# match inner patterns

test-data/unit/check-python310.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,8 +1613,8 @@ match m7:
16131613
m8: tuple[int, str] | list[float]
16141614
match m8:
16151615
case (a8, b8, *rest8):
1616-
reveal_type(a8) # N: Revealed type is "Union[builtins.int, builtins.float]"
1617-
reveal_type(b8) # N: Revealed type is "Union[builtins.str, builtins.float]"
1616+
reveal_type(a8) # N: Revealed type is "Union[builtins.float, builtins.int]"
1617+
reveal_type(b8) # N: Revealed type is "Union[builtins.float, builtins.str]"
16181618
reveal_type(rest8) # N: Revealed type is "builtins.list[builtins.float]"
16191619

16201620
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)