Skip to content

Commit 7be0e87

Browse files
inital commit
1 parent 8bd159b commit 7be0e87

File tree

2 files changed

+56
-13
lines changed

2 files changed

+56
-13
lines changed

mypy/checker.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7994,20 +7994,47 @@ def conditional_types(
79947994
) -> tuple[Type | None, Type | None]:
79957995
"""Takes in the current type and a proposed type of an expression.
79967996
7997-
Returns a 2-tuple: The first element is the proposed type, if the expression
7998-
can be the proposed type. The second element is the type it would hold
7999-
if it was not the proposed type, if any. UninhabitedType means unreachable.
8000-
None means no new information can be inferred. If default is set it is returned
8001-
instead."""
7997+
Returns a 2-tuple:
7998+
The first element is the proposed type, if the expression can be the proposed type.
7999+
The second element is the type it would hold if it was not the proposed type, if any.
8000+
UninhabitedType means unreachable.
8001+
None means no new information can be inferred.
8002+
If default is set it is returned instead.
8003+
"""
8004+
if proposed_type_ranges and len(proposed_type_ranges) == 1:
8005+
# expand e.g. bool -> Literal[True] | Literal[False]
8006+
target = proposed_type_ranges[0].item
8007+
target = get_proper_type(target)
8008+
if isinstance(target, LiteralType) and (
8009+
target.is_enum_literal() or isinstance(target.value, bool)
8010+
):
8011+
enum_name = target.fallback.type.fullname
8012+
current_type = try_expanding_sum_type_to_union(current_type, enum_name)
8013+
8014+
current_type = get_proper_type(current_type)
8015+
if isinstance(current_type, UnionType) and (default == current_type):
8016+
# factorize over union types
8017+
# if we try to narrow A|B to C, we instead narrow A to C and B to C, and
8018+
# return the union of the results
8019+
result: list[tuple[Type | None, Type | None]] = [
8020+
conditional_types(
8021+
union_item,
8022+
proposed_type_ranges,
8023+
default=union_item,
8024+
consider_runtime_isinstance=consider_runtime_isinstance,
8025+
)
8026+
for union_item in get_proper_types(current_type.items)
8027+
]
8028+
# separate list of tuples into two lists
8029+
yes_types, no_types = zip(*result)
8030+
yes_type = make_simplified_union([t for t in yes_types if t is not None])
8031+
no_type = restrict_subtype_away(
8032+
current_type, yes_type, consider_runtime_isinstance=consider_runtime_isinstance
8033+
)
8034+
8035+
return yes_type, no_type
8036+
80028037
if proposed_type_ranges:
8003-
if len(proposed_type_ranges) == 1:
8004-
target = proposed_type_ranges[0].item
8005-
target = get_proper_type(target)
8006-
if isinstance(target, LiteralType) and (
8007-
target.is_enum_literal() or isinstance(target.value, bool)
8008-
):
8009-
enum_name = target.fallback.type.fullname
8010-
current_type = try_expanding_sum_type_to_union(current_type, enum_name)
80118038
proposed_items = [type_range.item for type_range in proposed_type_ranges]
80128039
proposed_type = make_simplified_union(proposed_items)
80138040
if isinstance(proposed_type, AnyType):

test-data/unit/check-python310.test

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,6 +1760,22 @@ def union(x: str | bool) -> None:
17601760
reveal_type(x) # N: Revealed type is "Union[builtins.str, Literal[False]]"
17611761
[builtins fixtures/tuple.pyi]
17621762

1763+
[case testMatchNarrowDownUnionUsingClassPattern]
1764+
1765+
class Foo: ...
1766+
class Bar(Foo): ...
1767+
1768+
def test_1(bar: Bar) -> None:
1769+
match bar:
1770+
case Foo() as foo:
1771+
reveal_type(foo) # N: Revealed type is "__main__.Bar"
1772+
1773+
def test_2(bar: Bar | str) -> None:
1774+
match bar:
1775+
case Foo() as foo:
1776+
reveal_type(foo) # N: Revealed type is "__main__.Bar"
1777+
1778+
17631779
[case testMatchAssertFalseToSilenceFalsePositives]
17641780
class C:
17651781
a: int | str

0 commit comments

Comments
 (0)