Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@
find_member,
infer_class_variances,
is_callable_compatible,
is_enum_value_pair,
is_equivalent,
is_more_precise,
is_proper_subtype,
Expand Down Expand Up @@ -6628,6 +6629,7 @@ def equality_type_narrowing_helper(
if operator in {"is", "is not"}:
is_valid_target: Callable[[Type], bool] = is_singleton_type
coerce_only_in_literal_context = False
no_custom_eq = True
should_narrow_by_identity = True
else:

Expand All @@ -6643,21 +6645,31 @@ def has_no_custom_eq_checks(t: Type) -> bool:
coerce_only_in_literal_context = True

expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity = all(
map(has_no_custom_eq_checks, expr_types)
) and not is_ambiguous_mix_of_enums(expr_types)
no_custom_eq = all(map(has_no_custom_eq_checks, expr_types))
should_narrow_by_identity = not is_ambiguous_mix_of_enums(expr_types)

if_map: TypeMap = {}
else_map: TypeMap = {}
if should_narrow_by_identity:
if_map, else_map = self.refine_identity_comparison_expression(
if no_custom_eq:
# Try to narrow the types or at least identify unreachable blocks.
# If there's some mix of enums and values, we do not want to narrow enums
# to literals, but still want to detect unreachable branches.
if_map_optimistic, else_map_optimistic = self.refine_identity_comparison_expression(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
is_valid_target,
coerce_only_in_literal_context,
)
if should_narrow_by_identity:
if_map = if_map_optimistic
else_map = else_map_optimistic
else:
if if_map_optimistic is None:
if_map = None
if else_map_optimistic is None:
else_map = None

if if_map == {} and else_map == {}:
if_map, else_map = self.refine_away_none_in_comparison(
Expand Down Expand Up @@ -6905,13 +6917,16 @@ def should_coerce_inner(typ: Type) -> bool:
expr_type = coerce_to_literal(expr_type)
if not is_valid_target(get_proper_type(expr_type)):
continue
if target and not is_same_type(target, expr_type):
if (
target is not None
and not is_same_type(target, expr_type)
and not is_enum_value_pair(target, expr_type)
):
# We have multiple disjoint target types. So the 'if' branch
# must be unreachable.
return None, {}
target = expr_type
possible_target_indices.append(i)

# There's nothing we can currently infer if none of the operands are valid targets,
# so we end early and infer nothing.
if target is None:
Expand Down Expand Up @@ -9291,7 +9306,8 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]:
if t.last_known_value:
result.update(_ambiguous_enum_variants([t.last_known_value]))
elif t.type.is_enum and any(
base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro
base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int")
for base in t.type.mro
):
result.add(t.type.fullname)
elif not t.type.is_enum:
Expand Down
10 changes: 9 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
are_parameters_compatible,
find_member,
is_callable_compatible,
is_enum_value_pair,
is_equivalent,
is_proper_subtype,
is_same_type,
Expand Down Expand Up @@ -559,9 +560,16 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
right = right.fallback

if isinstance(left, LiteralType) and isinstance(right, LiteralType):
if left.value == right.value:
if (
left.value == right.value
and left.fallback.type.is_enum == right.fallback.type.is_enum
or is_enum_value_pair(left, right)
):
# If values are the same, we still need to check if fallbacks are overlapping,
# this is done below.
# Enums are more interesting:
# * if both sides are enums, they should have same values
# * if exactly one of them is a enum, fallback compatibibility is enough
left = left.fallback
right = right.fallback
else:
Expand Down
30 changes: 30 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from mypy.options import Options
from mypy.state import state
from mypy.types import (
ELLIPSIS_TYPE_NAMES,
MYPYC_NATIVE_INT_NAMES,
TUPLE_LIKE_INSTANCE_NAMES,
TYPED_NAMEDTUPLE_NAMES,
Expand Down Expand Up @@ -286,6 +287,35 @@ def is_same_type(
)


def is_enum_value_pair(a: Type, b: Type) -> bool:
a = get_proper_type(a)
b = get_proper_type(b)

if not isinstance(a, LiteralType) or not isinstance(b, LiteralType):
return False
if b.fallback.type.is_enum:
a, b = b, a
if b.fallback.type.is_enum or not a.fallback.type.is_enum:
return False
# At this point we have a pair (enum literal, non-enum literal).
# Check that the non-enum fallback is compatible
if not is_subtype(a.fallback, b.fallback):
return False
assert isinstance(a.value, str)
enum_value = a.fallback.type.get(a.value)
if enum_value is None or enum_value.type is None:
return False
proper_value = get_proper_type(enum_value.type)
return isinstance(proper_value, Instance) and (
proper_value.last_known_value == b
# TODO: this is too lax and should only be applied for enums defined in stubs,
# but checking that strictly requires access to the checker. This function
# is needed in `is_overlapping_types` and operates on a lower level,
# so doing this properly would be more difficult.
or proper_value.type.fullname in ELLIPSIS_TYPE_NAMES
)


# This is a common entry point for subtyping checks (both proper and non-proper).
# Never call this private function directly, use the public versions.
def _is_subtype(
Expand Down
Loading