Skip to content

Commit 7ddf821

Browse files
committed
Reapply "Narrow based on collection containment (python#17344)" (python#17865)
This reverts commit 329e38e.
1 parent eca206d commit 7ddf821

File tree

3 files changed

+127
-8
lines changed

3 files changed

+127
-8
lines changed

mypy/checker.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6063,11 +6063,16 @@ def has_no_custom_eq_checks(t: Type) -> bool:
60636063
if_map, else_map = {}, {}
60646064

60656065
if left_index in narrowable_operand_index_to_hash:
6066-
# We only try and narrow away 'None' for now
6067-
if is_overlapping_none(item_type):
6068-
collection_item_type = get_proper_type(
6069-
builtin_item_type(iterable_type)
6070-
)
6066+
collection_item_type = get_proper_type(builtin_item_type(iterable_type))
6067+
# Narrow if the collection is a subtype
6068+
if (
6069+
collection_item_type is not None
6070+
and collection_item_type != item_type
6071+
and is_subtype(collection_item_type, item_type)
6072+
):
6073+
if_map[operands[left_index]] = collection_item_type
6074+
# Try and narrow away 'None'
6075+
elif is_overlapping_none(item_type):
60716076
if (
60726077
collection_item_type is not None
60736078
and not is_overlapping_none(collection_item_type)

test-data/unit/check-narrowing.test

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,13 +1392,13 @@ else:
13921392
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
13931393

13941394
if val in (None,):
1395-
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1395+
reveal_type(val) # N: Revealed type is "None"
13961396
else:
13971397
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
13981398
if val not in (None,):
13991399
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
14001400
else:
1401-
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
1401+
reveal_type(val) # N: Revealed type is "None"
14021402
[builtins fixtures/primitives.pyi]
14031403

14041404
[case testNarrowingWithTupleOfTypes]
@@ -2333,3 +2333,110 @@ def f(x: C) -> None:
23332333

23342334
f(C(5))
23352335
[builtins fixtures/primitives.pyi]
2336+
2337+
[case testTypeNarrowingStringInLiteralUnion]
2338+
from typing import Literal, Tuple
2339+
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
2340+
x: str = "hi!"
2341+
if x in typ:
2342+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2343+
else:
2344+
reveal_type(x) # N: Revealed type is "builtins.str"
2345+
[builtins fixtures/tuple.pyi]
2346+
[typing fixtures/typing-medium.pyi]
2347+
2348+
[case testTypeNarrowingStringInLiteralUnionSubset]
2349+
from typing import Literal, Tuple
2350+
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b')
2351+
strIn: str = "b"
2352+
strOut: str = "c"
2353+
if strIn in typeAlpha:
2354+
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
2355+
else:
2356+
reveal_type(strIn) # N: Revealed type is "builtins.str"
2357+
if strOut in typeAlpha:
2358+
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
2359+
else:
2360+
reveal_type(strOut) # N: Revealed type is "builtins.str"
2361+
[builtins fixtures/primitives.pyi]
2362+
[typing fixtures/typing-medium.pyi]
2363+
2364+
[case testNarrowingStringNotInLiteralUnion]
2365+
from typing import Literal, Tuple
2366+
typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c')
2367+
strIn: str = "c"
2368+
strOut: str = "d"
2369+
if strIn not in typeAlpha:
2370+
reveal_type(strIn) # N: Revealed type is "builtins.str"
2371+
else:
2372+
reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
2373+
if strOut in typeAlpha:
2374+
reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]"
2375+
else:
2376+
reveal_type(strOut) # N: Revealed type is "builtins.str"
2377+
[builtins fixtures/primitives.pyi]
2378+
[typing fixtures/typing-medium.pyi]
2379+
2380+
[case testNarrowingStringInLiteralUnionDontExpand]
2381+
from typing import Literal, Tuple
2382+
typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c')
2383+
strIn: Literal['c'] = "c"
2384+
reveal_type(strIn) # N: Revealed type is "Literal['c']"
2385+
#Check we don't expand a Literal into the Union type
2386+
if strIn not in typeAlpha:
2387+
reveal_type(strIn) # N: Revealed type is "Literal['c']"
2388+
else:
2389+
reveal_type(strIn) # N: Revealed type is "Literal['c']"
2390+
[builtins fixtures/primitives.pyi]
2391+
[typing fixtures/typing-medium.pyi]
2392+
2393+
[case testTypeNarrowingStringInMixedUnion]
2394+
from typing import Literal, Tuple
2395+
typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b')
2396+
x: str = "hi!"
2397+
if x in typ:
2398+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2399+
else:
2400+
reveal_type(x) # N: Revealed type is "builtins.str"
2401+
[builtins fixtures/tuple.pyi]
2402+
[typing fixtures/typing-medium.pyi]
2403+
2404+
[case testTypeNarrowingStringInSet]
2405+
from typing import Literal, Set
2406+
typ: Set[Literal['a', 'b']] = {'a', 'b'}
2407+
x: str = "hi!"
2408+
if x in typ:
2409+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2410+
else:
2411+
reveal_type(x) # N: Revealed type is "builtins.str"
2412+
if x not in typ:
2413+
reveal_type(x) # N: Revealed type is "builtins.str"
2414+
else:
2415+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2416+
[builtins fixtures/narrowing.pyi]
2417+
[typing fixtures/typing-medium.pyi]
2418+
2419+
[case testTypeNarrowingStringInList]
2420+
from typing import Literal, List
2421+
typ: List[Literal['a', 'b']] = ['a', 'b']
2422+
x: str = "hi!"
2423+
if x in typ:
2424+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2425+
else:
2426+
reveal_type(x) # N: Revealed type is "builtins.str"
2427+
if x not in typ:
2428+
reveal_type(x) # N: Revealed type is "builtins.str"
2429+
else:
2430+
reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]"
2431+
[builtins fixtures/narrowing.pyi]
2432+
[typing fixtures/typing-medium.pyi]
2433+
2434+
[case testTypeNarrowingUnionStringFloat]
2435+
from typing import Union
2436+
def foobar(foo: Union[str, float]):
2437+
if foo in ['a', 'b']:
2438+
reveal_type(foo) # N: Revealed type is "builtins.str"
2439+
else:
2440+
reveal_type(foo) # N: Revealed type is "Union[builtins.str, builtins.float]"
2441+
[builtins fixtures/primitives.pyi]
2442+
[typing fixtures/typing-medium.pyi]

test-data/unit/fixtures/narrowing.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Builtins stub used in check-narrowing test cases.
2-
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union
2+
from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable
33

44

55
Tco = TypeVar('Tco', covariant=True)
@@ -15,6 +15,13 @@ class function: pass
1515
class ellipsis: pass
1616
class int: pass
1717
class str: pass
18+
class float: pass
1819
class dict(Generic[KT, VT]): pass
1920

2021
def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass
22+
23+
class list(Sequence[Tco]):
24+
def __contains__(self, other: object) -> bool: pass
25+
class set(Iterable[Tco], Generic[Tco]):
26+
def __init__(self, iterable: Iterable[Tco] = ...) -> None: ...
27+
def __contains__(self, item: object) -> bool: pass

0 commit comments

Comments
 (0)