Skip to content

Fix matching against union of tuples #19600

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 68 additions & 45 deletions mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,35 +244,77 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:

#
# get inner types of original type
#
# 1. Go through all possible types and filter to only those which are sequences that could match that number of items
# 2. If there is exactly one tuple left with an unpack, then use that type and the unpack index
# 3. Otherwise, take the product of the item types so that each index can have a unique type. For tuples with unpack
# fallback to merging all of their types for each index since we can't handle multiple unpacked items at once yet.

# Whether we have encountered a type that we don't know how to handle in the union
unknown_type = False
# A list of types that could match any of the items in the sequence.
sequence_types: list[Type] = []
# A list of tuple types that could match the sequence, per index
tuple_types: list[list[Type]] = []
# A list of all the unpack tuple types that we encountered, each containing the tuple type, unpack index, and union index
unpack_tuple_types: list[tuple[TupleType, int, int]] = []
for i, t in enumerate(
current_type.items if isinstance(current_type, UnionType) else [current_type]
):
t = get_proper_type(t)
if isinstance(t, TupleType):
t_items = list(t.items)
unpack_index = find_unpack_in_list(t_items)
if unpack_index is None:
size_diff = len(t_items) - required_patterns
if size_diff < 0:
continue
elif size_diff > 0 and star_position is None:
continue
elif not size_diff and star_position is not None:
t_items.append(UninhabitedType())
tuple_types.append(t_items)
else:
normalized_inner_types = []
for it in t_items:
# Unfortunately, it is not possible to "split" the TypeVarTuple
# into individual items, so we just use its upper bound for the whole
# analysis instead.
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
it = UnpackType(it.type.upper_bound)
normalized_inner_types.append(it)
t_items = normalized_inner_types
t = t.copy_modified(items=normalized_inner_types)
if len(t_items) - 1 > required_patterns and star_position is None:
continue
unpack_tuple_types.append((t, unpack_index, i))
# add the combined tuple type to the sequence types in case we have multiple unpacks we want to combine them all
sequence_types.append(self.chk.iterable_item_type(tuple_fallback(t), o))
elif isinstance(t, AnyType):
sequence_types.append(AnyType(TypeOfAny.from_another_any, t))
elif self.chk.type_is_iterable(t) and isinstance(t, Instance):
sequence_types.append(self.chk.iterable_item_type(t, o))
else:
unknown_type = True
# if we only got one unpack tuple type, we can use that
unpack_index = None
if isinstance(current_type, TupleType):
inner_types = current_type.items
unpack_index = find_unpack_in_list(inner_types)
if unpack_index is None:
size_diff = len(inner_types) - required_patterns
if size_diff < 0:
return self.early_non_match()
elif size_diff > 0 and star_position is None:
return self.early_non_match()
if len(unpack_tuple_types) == 1 and len(sequence_types) == 1 and not tuple_types:
update_tuple_type, unpack_index, union_index = unpack_tuple_types[0]
inner_types: list[Type] = update_tuple_type.items
if isinstance(current_type, UnionType):
union_items = list(current_type.items)
union_items[union_index] = update_tuple_type
current_type = current_type.copy_modified(items=union_items)
else:
normalized_inner_types = []
for it in inner_types:
# Unfortunately, it is not possible to "split" the TypeVarTuple
# into individual items, so we just use its upper bound for the whole
# analysis instead.
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
it = UnpackType(it.type.upper_bound)
normalized_inner_types.append(it)
inner_types = normalized_inner_types
current_type = current_type.copy_modified(items=normalized_inner_types)
if len(inner_types) - 1 > required_patterns and star_position is None:
return self.early_non_match()
current_type = update_tuple_type
# if we only got tuples we can't match, then exit early
elif not tuple_types and not sequence_types and not unknown_type:
return self.early_non_match()
elif tuple_types:
inner_types = [make_simplified_union([*sequence_types, *x]) for x in zip(*tuple_types)]
else:
inner_type = self.get_sequence_type(current_type, o)
if inner_type is None:
inner_type = self.chk.named_type("builtins.object")
inner_types = [inner_type] * len(o.patterns)
object_type = self.chk.named_type("builtins.object")
unioned = make_simplified_union(sequence_types) if sequence_types else object_type
inner_types = [unioned] * len(o.patterns)

#
# match inner patterns
Expand Down Expand Up @@ -351,25 +393,6 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
new_type = self.narrow_sequence_child(current_type, new_inner_type, o)
return PatternType(new_type, rest_type, captures)

def get_sequence_type(self, t: Type, context: Context) -> Type | None:
t = get_proper_type(t)
if isinstance(t, AnyType):
return AnyType(TypeOfAny.from_another_any, t)
if isinstance(t, UnionType):
items = [self.get_sequence_type(item, context) for item in t.items]
not_none_items = [item for item in items if item is not None]
if not_none_items:
return make_simplified_union(not_none_items)
else:
return None

if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)):
if isinstance(t, TupleType):
t = tuple_fallback(t)
return self.chk.iterable_item_type(t, context)
else:
return None

def contract_starred_pattern_types(
self, types: list[Type], star_pos: int | None, num_patterns: int
) -> list[Type]:
Expand Down
9 changes: 9 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2918,6 +2918,15 @@ def __init__(
self.original_str_expr: str | None = None
self.original_str_fallback: str | None = None

def copy_modified(self, *, items: Sequence[Type]) -> UnionType:
return UnionType(
items,
line=self.line,
column=self.column,
is_evaluated=self.is_evaluated,
uses_pep604_syntax=self.uses_pep604_syntax,
)

def can_be_true_default(self) -> bool:
return any(item.can_be_true for item in self.items)

Expand Down
61 changes: 61 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,67 @@ match m6:

[builtins fixtures/tuple.pyi]

[case testMatchTupleUnions]
from typing_extensions import Unpack

m1: tuple[int, str] | None
match m1:
case (a1, b1):
reveal_type(a1) # N: Revealed type is "builtins.int"
reveal_type(b1) # N: Revealed type is "builtins.str"

m2: tuple[int, str] | tuple[float, str]
match m2:
case (a2, b2):
reveal_type(a2) # N: Revealed type is "Union[builtins.int, builtins.float]"
reveal_type(b2) # N: Revealed type is "builtins.str"

m3: tuple[int] | tuple[float, str]
match m3:
case (a3, b3):
reveal_type(a3) # N: Revealed type is "builtins.float"
reveal_type(b3) # N: Revealed type is "builtins.str"

m4: tuple[int] | list[str]
match m4:
case (a4, b4):
reveal_type(a4) # N: Revealed type is "builtins.str"
reveal_type(b4) # N: Revealed type is "builtins.str"

# properly handles unpack when all other patterns are not sequences
m5: tuple[int, Unpack[tuple[float, ...]]] | None
match m5:
case (a5, b5):
reveal_type(a5) # N: Revealed type is "builtins.int"
reveal_type(b5) # N: Revealed type is "builtins.float"

# currently can't handle combing unpacking with other sequence patterns, if this happens revert to worst case
# of combing all types
m6: tuple[int, Unpack[tuple[float, ...]]] | list[str]
match m6:
case (a6, b6):
reveal_type(a6) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]"
reveal_type(b6) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]"

# but do still separate types from non unpacked types
m7: tuple[int, Unpack[tuple[float, ...]]] | tuple[str, str]
match m7:
case (a7, b7, *rest7):
reveal_type(a7) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]"
reveal_type(b7) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]"
reveal_type(rest7) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.float]]"

# verify that if we are unpacking, it will get the type of the sequence if the tuple is too short
m8: tuple[int, str] | list[float]
match m8:
case (a8, b8, *rest8):
reveal_type(a8) # N: Revealed type is "Union[builtins.float, builtins.int]"
reveal_type(b8) # N: Revealed type is "Union[builtins.float, builtins.str]"
reveal_type(rest8) # N: Revealed type is "builtins.list[builtins.float]"

[builtins fixtures/tuple.pyi]


[case testMatchEnumSingleChoice]
from enum import Enum
from typing import NoReturn
Expand Down
Loading