Skip to content

Commit 7a5f33e

Browse files
committed
Rewrite type equals narrowing with type ranges
1 parent 91fbc3e commit 7a5f33e

File tree

2 files changed

+57
-87
lines changed

2 files changed

+57
-87
lines changed

mypy/checker.py

Lines changed: 55 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import functools
56
import itertools
67
from collections import defaultdict
78
from collections.abc import Iterable, Iterator, Mapping, Sequence, Set as AbstractSet
@@ -25,7 +26,6 @@
2526
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
2627
from mypy.checker_shared import CheckerScope, TypeCheckerSharedApi, TypeRange
2728
from mypy.checker_state import checker_state
28-
from mypy.checkexpr import type_info_from_type
2929
from mypy.checkmember import (
3030
MemberContext,
3131
analyze_class_attribute_access,
@@ -6239,115 +6239,86 @@ def is_type_call(expr: CallExpr) -> bool:
62396239
# exprs that are being passed into type
62406240
exprs_in_type_calls: list[Expression] = []
62416241
# all the types that an expression will have if the overall expression is truthy
6242-
target_types: list[Instance] = []
6242+
target_types: list[list[TypeRange]] = []
62436243
# only a single type can be used when passed directly (eg "str")
6244-
fixed_type: TypeRange | None = None
6244+
fixed_type: Type | None = None
62456245
# is this single type final?
62466246
is_final = False
62476247

6248-
def update_fixed_type(new_fixed_type: TypeRange, new_is_final: bool) -> bool:
6248+
def update_fixed_type(new_fixed_type: Type, new_is_final: bool) -> bool:
62496249
"""Returns if the update succeeds"""
62506250
nonlocal fixed_type, is_final
6251-
if update := (
6252-
fixed_type is None
6253-
or (
6254-
new_fixed_type.is_upper_bound == fixed_type.is_upper_bound
6255-
and is_same_type(new_fixed_type.item, fixed_type.item)
6256-
)
6257-
):
6251+
if update := (fixed_type is None or (is_same_type(new_fixed_type, fixed_type))):
62586252
fixed_type = new_fixed_type
62596253
is_final = new_is_final
62606254
return update
62616255

62626256
for index in expr_indices:
62636257
expr = node.operands[index]
6258+
proper_type = get_proper_type(self.lookup_type(expr))
62646259

62656260
if isinstance(expr, CallExpr) and is_type_call(expr):
62666261
arg = expr.args[0]
62676262
exprs_in_type_calls.append(arg)
6268-
typ = self.lookup_type(arg)
6269-
else:
6270-
proper_type = get_proper_type(self.lookup_type(expr))
6271-
# get the range as though we were using isinstance
6272-
type_range = self.isinstance_type_range(proper_type)
6273-
# None range means this should not be used in comparison (eg tuple)
6274-
if type_range is None:
6275-
fixed_type = TypeRange(UninhabitedType(), True)
6276-
continue
6263+
elif (
6264+
isinstance(expr, OpExpr)
6265+
or isinstance(proper_type, TupleType)
6266+
or is_named_instance(proper_type, "builtins.tuple")
6267+
):
6268+
# not valid for type comparisons, but allowed for isinstance checks
6269+
fixed_type = UninhabitedType()
6270+
continue
62776271

6278-
if isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo):
6279-
if not update_fixed_type(type_range, expr.node.is_final):
6272+
type_range = self.get_isinstance_type(expr)
6273+
if type_range is not None:
6274+
target_types.append(type_range)
6275+
if (
6276+
isinstance(expr, RefExpr)
6277+
and isinstance(expr.node, TypeInfo)
6278+
and len(type_range) == 1
6279+
):
6280+
if not update_fixed_type(
6281+
Instance(
6282+
expr.node,
6283+
[AnyType(TypeOfAny.special_form)] * len(expr.node.defn.type_vars),
6284+
),
6285+
expr.node.is_final,
6286+
):
62806287
return None, {}
6281-
typ = type_range.item
6282-
6283-
type_info = type_info_from_type(typ)
6284-
if type_info is not None:
6285-
target_types.append(
6286-
Instance(
6287-
type_info,
6288-
[AnyType(TypeOfAny.special_form)] * len(type_info.defn.type_vars),
6289-
)
6290-
)
62916288

62926289
if not exprs_in_type_calls:
62936290
return {}, {}
62946291

6295-
if target_types:
6296-
least_type: Type | None = target_types[0]
6297-
for target_type in target_types[1:]:
6298-
if fixed_type is None:
6299-
# intersect types if fixed type doesn't need keeping
6300-
least_type = self.intersect_instances(
6301-
(cast(Instance, least_type), target_type), [] # what to do with errors?
6302-
)
6303-
else:
6304-
# otherwise, be safe and use meet
6305-
least_type = meet_types(cast(Type, least_type), target_type)
6306-
if least_type is None:
6307-
break
6308-
elif fixed_type:
6309-
least_type = fixed_type.item
6310-
else:
6311-
# no bounds means no inference can be made
6312-
return {}, {}
6313-
6314-
# if the type differs from the fixed type, comparison cannot succeed
6315-
if least_type is None or (
6316-
fixed_type is not None and not is_same_type(least_type, fixed_type.item)
6317-
):
6318-
return None, {}
6319-
6320-
shared_type = [TypeRange(least_type, not is_final)]
6321-
6322-
if_maps: list[TypeMap] = []
6323-
else_maps: list[TypeMap] = []
6292+
if_maps = []
6293+
else_maps = []
63246294
for expr in exprs_in_type_calls:
6325-
if_map, else_map = conditional_types_to_typemaps(
6326-
expr,
6327-
*self.conditional_types_with_intersection(
6328-
self.lookup_type(expr), shared_type, expr
6329-
),
6330-
)
6295+
expr_type = get_proper_type(self.lookup_type(expr))
6296+
for type_range in target_types:
6297+
new_expr_type, _ = self.conditional_types_with_intersection(
6298+
expr_type, type_range, expr
6299+
)
6300+
if new_expr_type is not None:
6301+
new_expr_type = get_proper_type(new_expr_type)
6302+
if isinstance(expr_type, AnyType):
6303+
expr_type = new_expr_type
6304+
elif not isinstance(new_expr_type, AnyType):
6305+
expr_type = meet_types(expr_type, new_expr_type)
6306+
_, else_map = conditional_types_to_typemaps(
6307+
expr,
6308+
*self.conditional_types_with_intersection(
6309+
(self.lookup_type(expr)), (type_range), expr
6310+
),
6311+
)
6312+
else_maps.append(else_map)
6313+
if fixed_type and expr_type is not None:
6314+
expr_type = meet_types(expr_type, fixed_type)
6315+
6316+
if_map, _ = conditional_types_to_typemaps(expr, expr_type, None)
63316317
if_maps.append(if_map)
6332-
else_maps.append(else_map)
63336318

6334-
def combine_maps(list_maps: list[TypeMap]) -> TypeMap:
6335-
"""Combine all typemaps in list_maps into one typemap"""
6336-
if all(m is None for m in list_maps):
6337-
return None
6338-
result_map = {}
6339-
for d in list_maps:
6340-
if d is not None:
6341-
result_map.update(d)
6342-
return result_map
6343-
6344-
if_map = combine_maps(if_maps)
6345-
# type(x) == T is only true when x has the same type as T, meaning
6346-
# that it can be false if x is an instance of a subclass of T. That means
6347-
# we can't do any narrowing in the else case unless T is final, in which
6348-
# case T can't be subclassed
6319+
if_map = functools.reduce(and_conditional_maps, if_maps)
63496320
if is_final:
6350-
else_map = combine_maps(else_maps)
6321+
else_map = functools.reduce(or_conditional_maps, else_maps)
63516322
else:
63526323
else_map = {}
63536324
return if_map, else_map

test-data/unit/check-isinstance.test

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2772,7 +2772,7 @@ x: X = z
27722772
y: Y = z
27732773
if type(x) is type(y):
27742774
reveal_type(x) # N: Revealed type is "__main__.<subclass of "__main__.X" and "__main__.Y">"
2775-
reveal_type(y) # N: Revealed type is "__main__.<subclass of "__main__.X" and "__main__.Y">"
2775+
reveal_type(y) # N: Revealed type is "__main__.<subclass of "__main__.Y" and "__main__.X">"
27762776
x.y + y.x
27772777

27782778
if isinstance(x, type(y)) and isinstance(y, type(x)):
@@ -2788,10 +2788,9 @@ from typing import Union
27882788

27892789
y: str
27902790
if type(y) is int:
2791-
y # E: Statement is unreachable
2791+
y
27922792
else:
27932793
reveal_type(y) # N: Revealed type is "builtins.str"
2794-
[builtins fixtures/isinstance.pyi]
27952794

27962795
[case testTypeEqualsCheckUsingIsNonOverlappingChild-xfail]
27972796
# flags: --warn-unreachable

0 commit comments

Comments
 (0)