Skip to content

Commit 853b762

Browse files
almost working
1 parent 6a179f7 commit 853b762

File tree

3 files changed

+123
-54
lines changed

3 files changed

+123
-54
lines changed

mypy/checkpattern.py

Lines changed: 73 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from __future__ import annotations
44

55
from collections import defaultdict
6-
from typing import Final, NamedTuple
6+
from itertools import zip_longest
7+
from typing import Final, Literal, NamedTuple
78

89
from mypy import message_registry
910
from mypy.checker_shared import TypeCheckerSharedApi, TypeRange
@@ -245,34 +246,79 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
245246
#
246247
# get inner types of original type
247248
#
248-
unpack_index = None
249-
if isinstance(current_type, TupleType):
250-
inner_types = current_type.items
251-
unpack_index = find_unpack_in_list(inner_types)
252-
if unpack_index is None:
253-
size_diff = len(inner_types) - required_patterns
254-
if size_diff < 0:
255-
return self.early_non_match()
256-
elif size_diff > 0 and star_position is None:
257-
return self.early_non_match()
249+
# 1. Go through all possible types and filter to only those which are sequences that could match that number of items
250+
# 2. If there are multiple tuples left with unpacks, then use the fallback logic where we union all items types
251+
# 3. Otherwise, take the product of the item types so that each index can have a unique type
252+
253+
# state of matching
254+
state: (
255+
# Start in the state where we not encountered an unpack.
256+
# a list of all the possible types that could match the sequence. If it's a tuple, then store one for each index
257+
tuple[Literal["NO_UNPACK"], list[list[Type]]] |
258+
# If we encounter a single tuple with an unpack, store the type, the unpack index, and the index in the union type
259+
tuple[Literal["UNPACK"], TupleType, int, int] |
260+
# If we have encountered a tuple with an unpack plus any other types, then store a list of them. For any tuples
261+
# without unpacks, store them as a list of their items.
262+
tuple[Literal["MULTI_UNPACK"], list[list[Type]]]
263+
) = ("NO_UNPACK", [])
264+
for i, t in enumerate(current_type.items) if isinstance(current_type, UnionType) else ((0, current_type),):
265+
t = get_proper_type(t)
266+
n_patterns = len(o.patterns)
267+
if isinstance(t, TupleType):
268+
t_items = t.items
269+
unpack_index = find_unpack_in_list(t_items)
270+
if unpack_index is None:
271+
size_diff = len(t_items) - required_patterns
272+
if size_diff < 0:
273+
continue
274+
elif size_diff > 0 and star_position is None:
275+
continue
276+
inner_t = list(t_items)
277+
else:
278+
normalized_inner_types = []
279+
for it in t_items:
280+
# Unfortunately, it is not possible to "split" the TypeVarTuple
281+
# into individual items, so we just use its upper bound for the whole
282+
# analysis instead.
283+
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
284+
it = UnpackType(it.type.upper_bound)
285+
normalized_inner_types.append(it)
286+
t_items = normalized_inner_types
287+
t = t.copy_modified(items=normalized_inner_types)
288+
if len(t_items) - 1 > required_patterns and star_position is None:
289+
continue
290+
if state[0] == "NO_UNPACK" and not state[1]:
291+
state = ("UNPACK", t, unpack_index, i)
292+
continue
293+
inner_t = [self.chk.iterable_item_type(tuple_fallback(t), o)] * n_patterns
294+
elif isinstance(t, AnyType):
295+
inner_t = [AnyType(TypeOfAny.from_another_any, t)] * n_patterns
296+
elif self.chk.type_is_iterable(t) and isinstance(t, Instance):
297+
inner_t = [self.chk.iterable_item_type(t, o)] * n_patterns
258298
else:
259-
normalized_inner_types = []
260-
for it in inner_types:
261-
# Unfortunately, it is not possible to "split" the TypeVarTuple
262-
# into individual items, so we just use its upper bound for the whole
263-
# analysis instead.
264-
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
265-
it = UnpackType(it.type.upper_bound)
266-
normalized_inner_types.append(it)
267-
inner_types = normalized_inner_types
268-
current_type = current_type.copy_modified(items=normalized_inner_types)
269-
if len(inner_types) - 1 > required_patterns and star_position is None:
270-
return self.early_non_match()
299+
continue
300+
# if we previously encountered an unpack, then change the state.
301+
if state[0] == "UNPACK":
302+
# if we already unpacked something, change this
303+
state = ("MULTI_UNPACK", [[self.chk.iterable_item_type(tuple_fallback(state[1]), o)] * n_patterns])
304+
assert state[0] != "UNPACK" # for type checker
305+
state[1].append(inner_t)
306+
307+
if state[0] == "UNPACK":
308+
_, update_tuple_type, unpack_index, union_index = state
309+
inner_types = update_tuple_type.items
310+
if isinstance(current_type, UnionType):
311+
union_items = list(current_type.items)
312+
union_items[union_index] = update_tuple_type
313+
current_type = current_type.copy_modified(items=union_items)
314+
else:
315+
assert unpack_index == 0, "Unpack index should be 0 for non-union types"
316+
current_type = update_tuple_type
271317
else:
272-
inner_type = self.get_sequence_type(current_type, o)
273-
if inner_type is None:
274-
inner_type = self.chk.named_type("builtins.object")
275-
inner_types = [inner_type] * len(o.patterns)
318+
unpack_index = None
319+
if not state[1]:
320+
return self.early_non_match()
321+
inner_types = [make_simplified_union(x) for x in zip_longest(*state[1], fillvalue=UninhabitedType())]
276322

277323
#
278324
# match inner patterns
@@ -351,25 +397,6 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
351397
new_type = self.narrow_sequence_child(current_type, new_inner_type, o)
352398
return PatternType(new_type, rest_type, captures)
353399

354-
def get_sequence_type(self, t: Type, context: Context) -> Type | None:
355-
t = get_proper_type(t)
356-
if isinstance(t, AnyType):
357-
return AnyType(TypeOfAny.from_another_any, t)
358-
if isinstance(t, UnionType):
359-
items = [self.get_sequence_type(item, context) for item in t.items]
360-
not_none_items = [item for item in items if item is not None]
361-
if not_none_items:
362-
return make_simplified_union(not_none_items)
363-
else:
364-
return None
365-
366-
if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)):
367-
if isinstance(t, TupleType):
368-
t = tuple_fallback(t)
369-
return self.chk.iterable_item_type(t, context)
370-
else:
371-
return None
372-
373400
def contract_starred_pattern_types(
374401
self, types: list[Type], star_pos: int | None, num_patterns: int
375402
) -> list[Type]:

mypy/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2918,6 +2918,15 @@ def __init__(
29182918
self.original_str_expr: str | None = None
29192919
self.original_str_fallback: str | None = None
29202920

2921+
def copy_modified(self, *, items: Sequence[Type]) -> UnionType:
2922+
return UnionType(
2923+
items,
2924+
line=self.line,
2925+
column=self.column,
2926+
is_evaluated=self.is_evaluated,
2927+
uses_pep604_syntax=self.uses_pep604_syntax,
2928+
)
2929+
29212930
def can_be_true_default(self) -> bool:
29222931
return any(item.can_be_true for item in self.items)
29232932

test-data/unit/check-python310.test

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,29 +1560,62 @@ match m6:
15601560
[builtins fixtures/tuple.pyi]
15611561

15621562
[case testMatchTupleUnions]
1563+
from typing_extensions import Unpack
1564+
15631565
m1: tuple[int, str] | None
15641566
match m1:
15651567
case (a1, b1):
1566-
reveal_type(a1) # N: Revealed type is "int"
1567-
reveal_type(b1) # N: Revealed type is "str"
1568+
reveal_type(a1) # N: Revealed type is "builtins.int"
1569+
reveal_type(b1) # N: Revealed type is "builtins.str"
15681570

15691571
m2: tuple[int, str] | tuple[float, str]
15701572
match m2:
15711573
case (a2, b2):
1572-
reveal_type(a2) # N: Revealed type is "Union[int, float]"
1573-
reveal_type(b2) # N: Revealed type is "str"
1574+
reveal_type(a2) # N: Revealed type is "Union[builtins.int, builtins.float]"
1575+
reveal_type(b2) # N: Revealed type is "builtins.str"
15741576

15751577
m3: tuple[int] | tuple[float, str]
15761578
match m3:
15771579
case (a3, b3):
1578-
reveal_type(a3) # N: Revealed type is "float"
1579-
reveal_type(b3) # N: Revealed type is "str"
1580+
reveal_type(a3) # N: Revealed type is "builtins.float"
1581+
reveal_type(b3) # N: Revealed type is "builtins.str"
15801582

15811583
m4: tuple[int] | list[str]
15821584
match m4:
15831585
case (a4, b4):
1584-
reveal_type(a4) # N: Revealed type is "float"
1585-
reveal_type(b4) # N: Revealed type is "str"
1586+
reveal_type(a4) # N: Revealed type is "builtins.str"
1587+
reveal_type(b4) # N: Revealed type is "builtins.str"
1588+
1589+
# properly handles unpack when all other patterns are not sequences
1590+
m5: tuple[int, Unpack[tuple[float, ...]]] | None
1591+
match m5:
1592+
case (a5, b5):
1593+
reveal_type(a5) # N: Revealed type is "builtins.int"
1594+
reveal_type(b5) # N: Revealed type is "builtins.float"
1595+
1596+
# currently can't handle combing unpacking with other sequence patterns, if this happens revert to worst case
1597+
# of combing all types
1598+
m6: tuple[int, Unpack[tuple[float, ...]]] | list[str]
1599+
match m6:
1600+
case (a6, b6):
1601+
reveal_type(a6) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]"
1602+
reveal_type(b6) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]"
1603+
1604+
# but do still seperate types from non unpacked types
1605+
m7: tuple[int, Unpack[tuple[float, ...]]] | tuple[str, bool]
1606+
match m7:
1607+
case (a7, b7, *rest7):
1608+
reveal_type(a7) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]"
1609+
reveal_type(b7) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.bool]"
1610+
reveal_type(rest7) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.float]]"
1611+
1612+
# verify that if we are unpacking, it will get the type of the sequence if the tuple is too short
1613+
m8: tuple[int, str] | list[float]
1614+
match m8:
1615+
case (a8, b8, *rest8):
1616+
reveal_type(a8) # N: Revealed type is "Union[builtins.int, builtins.float]"
1617+
reveal_type(b8) # N: Revealed type is "Union[builtins.str, builtins.float]"
1618+
reveal_type(rest8) # N: Revealed type is "builtins.list[builtins.float]"
15861619

15871620
[builtins fixtures/tuple.pyi]
15881621

0 commit comments

Comments
 (0)