diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 48840466f0d8..8f38807168b0 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -244,35 +244,77 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: # # get inner types of original type - # + # 1. Go through all possible types and filter to only those which are sequences that could match that number of items + # 2. If there is exactly one tuple left with an unpack, then use that type and the unpack index + # 3. Otherwise, take the product of the item types so that each index can have a unique type. For tuples with unpack + # fallback to merging all of their types for each index since we can't handle multiple unpacked items at once yet. + + # Whether we have encountered a type that we don't know how to handle in the union + unknown_type = False + # A list of types that could match any of the items in the sequence. + sequence_types: list[Type] = [] + # A list of tuple types that could match the sequence, per index + tuple_types: list[list[Type]] = [] + # A list of all the unpack tuple types that we encountered, each containing the tuple type, unpack index, and union index + unpack_tuple_types: list[tuple[TupleType, int, int]] = [] + for i, t in enumerate( + current_type.items if isinstance(current_type, UnionType) else [current_type] + ): + t = get_proper_type(t) + if isinstance(t, TupleType): + t_items = list(t.items) + unpack_index = find_unpack_in_list(t_items) + if unpack_index is None: + size_diff = len(t_items) - required_patterns + if size_diff < 0: + continue + elif size_diff > 0 and star_position is None: + continue + elif not size_diff and star_position is not None: + t_items.append(UninhabitedType()) + tuple_types.append(t_items) + else: + normalized_inner_types = [] + for it in t_items: + # Unfortunately, it is not possible to "split" the TypeVarTuple + # into individual items, so we just use its upper bound for the whole + # analysis instead. + if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType): + it = UnpackType(it.type.upper_bound) + normalized_inner_types.append(it) + t_items = normalized_inner_types + t = t.copy_modified(items=normalized_inner_types) + if len(t_items) - 1 > required_patterns and star_position is None: + continue + unpack_tuple_types.append((t, unpack_index, i)) + # add the combined tuple type to the sequence types in case we have multiple unpacks we want to combine them all + sequence_types.append(self.chk.iterable_item_type(tuple_fallback(t), o)) + elif isinstance(t, AnyType): + sequence_types.append(AnyType(TypeOfAny.from_another_any, t)) + elif self.chk.type_is_iterable(t) and isinstance(t, Instance): + sequence_types.append(self.chk.iterable_item_type(t, o)) + else: + unknown_type = True + # if we only got one unpack tuple type, we can use that unpack_index = None - if isinstance(current_type, TupleType): - inner_types = current_type.items - unpack_index = find_unpack_in_list(inner_types) - if unpack_index is None: - size_diff = len(inner_types) - required_patterns - if size_diff < 0: - return self.early_non_match() - elif size_diff > 0 and star_position is None: - return self.early_non_match() + if len(unpack_tuple_types) == 1 and len(sequence_types) == 1 and not tuple_types: + update_tuple_type, unpack_index, union_index = unpack_tuple_types[0] + inner_types: list[Type] = update_tuple_type.items + if isinstance(current_type, UnionType): + union_items = list(current_type.items) + union_items[union_index] = update_tuple_type + current_type = current_type.copy_modified(items=union_items) else: - normalized_inner_types = [] - for it in inner_types: - # Unfortunately, it is not possible to "split" the TypeVarTuple - # into individual items, so we just use its upper bound for the whole - # analysis instead. - if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType): - it = UnpackType(it.type.upper_bound) - normalized_inner_types.append(it) - inner_types = normalized_inner_types - current_type = current_type.copy_modified(items=normalized_inner_types) - if len(inner_types) - 1 > required_patterns and star_position is None: - return self.early_non_match() + current_type = update_tuple_type + # if we only got tuples we can't match, then exit early + elif not tuple_types and not sequence_types and not unknown_type: + return self.early_non_match() + elif tuple_types: + inner_types = [make_simplified_union([*sequence_types, *x]) for x in zip(*tuple_types)] else: - inner_type = self.get_sequence_type(current_type, o) - if inner_type is None: - inner_type = self.chk.named_type("builtins.object") - inner_types = [inner_type] * len(o.patterns) + object_type = self.chk.named_type("builtins.object") + unioned = make_simplified_union(sequence_types) if sequence_types else object_type + inner_types = [unioned] * len(o.patterns) # # match inner patterns @@ -351,25 +393,6 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: new_type = self.narrow_sequence_child(current_type, new_inner_type, o) return PatternType(new_type, rest_type, captures) - def get_sequence_type(self, t: Type, context: Context) -> Type | None: - t = get_proper_type(t) - if isinstance(t, AnyType): - return AnyType(TypeOfAny.from_another_any, t) - if isinstance(t, UnionType): - items = [self.get_sequence_type(item, context) for item in t.items] - not_none_items = [item for item in items if item is not None] - if not_none_items: - return make_simplified_union(not_none_items) - else: - return None - - if self.chk.type_is_iterable(t) and isinstance(t, (Instance, TupleType)): - if isinstance(t, TupleType): - t = tuple_fallback(t) - return self.chk.iterable_item_type(t, context) - else: - return None - def contract_starred_pattern_types( self, types: list[Type], star_pos: int | None, num_patterns: int ) -> list[Type]: diff --git a/mypy/types.py b/mypy/types.py index b4771b15f77a..5f25285404aa 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2918,6 +2918,15 @@ def __init__( self.original_str_expr: str | None = None self.original_str_fallback: str | None = None + def copy_modified(self, *, items: Sequence[Type]) -> UnionType: + return UnionType( + items, + line=self.line, + column=self.column, + is_evaluated=self.is_evaluated, + uses_pep604_syntax=self.uses_pep604_syntax, + ) + def can_be_true_default(self) -> bool: return any(item.can_be_true for item in self.items) diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index f264167cb067..dfa2d16e375c 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1559,6 +1559,67 @@ match m6: [builtins fixtures/tuple.pyi] +[case testMatchTupleUnions] +from typing_extensions import Unpack + +m1: tuple[int, str] | None +match m1: + case (a1, b1): + reveal_type(a1) # N: Revealed type is "builtins.int" + reveal_type(b1) # N: Revealed type is "builtins.str" + +m2: tuple[int, str] | tuple[float, str] +match m2: + case (a2, b2): + reveal_type(a2) # N: Revealed type is "Union[builtins.int, builtins.float]" + reveal_type(b2) # N: Revealed type is "builtins.str" + +m3: tuple[int] | tuple[float, str] +match m3: + case (a3, b3): + reveal_type(a3) # N: Revealed type is "builtins.float" + reveal_type(b3) # N: Revealed type is "builtins.str" + +m4: tuple[int] | list[str] +match m4: + case (a4, b4): + reveal_type(a4) # N: Revealed type is "builtins.str" + reveal_type(b4) # N: Revealed type is "builtins.str" + +# properly handles unpack when all other patterns are not sequences +m5: tuple[int, Unpack[tuple[float, ...]]] | None +match m5: + case (a5, b5): + reveal_type(a5) # N: Revealed type is "builtins.int" + reveal_type(b5) # N: Revealed type is "builtins.float" + +# currently can't handle combing unpacking with other sequence patterns, if this happens revert to worst case +# of combing all types +m6: tuple[int, Unpack[tuple[float, ...]]] | list[str] +match m6: + case (a6, b6): + reveal_type(a6) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]" + reveal_type(b6) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]" + +# but do still separate types from non unpacked types +m7: tuple[int, Unpack[tuple[float, ...]]] | tuple[str, str] +match m7: + case (a7, b7, *rest7): + reveal_type(a7) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]" + reveal_type(b7) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.str]" + reveal_type(rest7) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.float]]" + +# verify that if we are unpacking, it will get the type of the sequence if the tuple is too short +m8: tuple[int, str] | list[float] +match m8: + case (a8, b8, *rest8): + reveal_type(a8) # N: Revealed type is "Union[builtins.float, builtins.int]" + reveal_type(b8) # N: Revealed type is "Union[builtins.float, builtins.str]" + reveal_type(rest8) # N: Revealed type is "builtins.list[builtins.float]" + +[builtins fixtures/tuple.pyi] + + [case testMatchEnumSingleChoice] from enum import Enum from typing import NoReturn