Skip to content

Commit 2bc24d8

Browse files
Refactor: modularize check_op_reversible into helper methods
1 parent fbb411f commit 2bc24d8

File tree

1 file changed

+118
-82
lines changed

1 file changed

+118
-82
lines changed

mypy/checkexpr.py

Lines changed: 118 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3968,54 +3968,14 @@ def check_op_reversible(
39683968
right_expr: Expression,
39693969
context: Context,
39703970
) -> tuple[Type, Type]:
3971-
def lookup_operator(op_name: str, base_type: Type) -> Type | None:
3972-
"""Looks up the given operator and returns the corresponding type,
3973-
if it exists."""
3974-
3975-
# This check is an important performance optimization,
3976-
# even though it is mostly a subset of
3977-
# analyze_member_access.
3978-
# TODO: Find a way to remove this call without performance implications.
3979-
if not self.has_member(base_type, op_name):
3980-
return None
3981-
3982-
with self.msg.filter_errors() as w:
3983-
member = analyze_member_access(
3984-
name=op_name,
3985-
typ=base_type,
3986-
is_lvalue=False,
3987-
is_super=False,
3988-
is_operator=True,
3989-
original_type=base_type,
3990-
context=context,
3991-
chk=self.chk,
3992-
in_literal_context=self.is_literal_context(),
3993-
)
3994-
return None if w.has_new_errors() else member
3995-
3996-
def lookup_definer(typ: Instance, attr_name: str) -> str | None:
3997-
"""Returns the name of the class that contains the actual definition of attr_name.
3998-
3999-
So if class A defines foo and class B subclasses A, running
4000-
'get_class_defined_in(B, "foo")` would return the full name of A.
4001-
4002-
However, if B were to override and redefine foo, that method call would
4003-
return the full name of B instead.
4004-
4005-
If the attr name is not present in the given class or its MRO, returns None.
4006-
"""
4007-
for cls in typ.type.mro:
4008-
if cls.names.get(attr_name):
4009-
return cls.fullname
4010-
return None
3971+
"""Check a binary operator for types where evaluation order matters."""
40113972

40123973
left_type = get_proper_type(left_type)
40133974
right_type = get_proper_type(right_type)
40143975

40153976
# If either the LHS or the RHS are Any, we can't really concluding anything
40163977
# about the operation since the Any type may or may not define an
40173978
# __op__ or __rop__ method. So, we punt and return Any instead.
4018-
40193979
if isinstance(left_type, AnyType):
40203980
any_type = AnyType(TypeOfAny.from_another_any, source_any=left_type)
40213981
return any_type, any_type
@@ -4025,82 +3985,157 @@ def lookup_definer(typ: Instance, attr_name: str) -> str | None:
40253985

40263986
# STEP 1:
40273987
# We start by getting the __op__ and __rop__ methods, if they exist.
4028-
40293988
rev_op_name = operators.reverse_op_methods[op_name]
3989+
left_op = self._lookup_operator(op_name, left_type, context)
3990+
right_op = self._lookup_operator(rev_op_name, right_type, context)
40303991

4031-
left_op = lookup_operator(op_name, left_type)
4032-
right_op = lookup_operator(rev_op_name, right_type)
4033-
4034-
# STEP 2a:
3992+
# STEP 2:
40353993
# We figure out in which order Python will call the operator methods. As it
40363994
# turns out, it's not as simple as just trying to call __op__ first and
40373995
# __rop__ second.
40383996
#
40393997
# We store the determined order inside the 'variants_raw' variable,
40403998
# which records tuples containing the method, base type, and the argument.
3999+
variants_raw = self._determine_operator_order(
4000+
op_name, rev_op_name, left_type, right_type, left_expr, right_expr, left_op, right_op
4001+
)
4002+
4003+
# STEP 3:
4004+
# We now filter out all non-existent operators. The 'variants' list contains
4005+
# all operator methods that are actually present, in the order that Python
4006+
# attempts to invoke them.
4007+
variants = [
4008+
(name, op, obj, arg) for (name, op, obj, arg) in variants_raw if op is not None
4009+
]
40414010

4011+
# STEP 4:
4012+
# We now try invoking each one. If an operation succeeds, end early and return
4013+
# the corresponding result. Otherwise, return the result and errors associated
4014+
# with the first entry.
4015+
return self._attempt_operator_applications(
4016+
op_name, variants, left_type, right_type, left_expr, right_expr, context
4017+
)
4018+
4019+
4020+
def _lookup_operator(self, op_name: str, base_type: Type, context: Context) -> Type | None:
4021+
"""Look up the given operator and return the corresponding type, if it exists."""
4022+
4023+
# This check is an important performance optimization,
4024+
# even though it is mostly a subset of analyze_member_access.
4025+
# TODO: Find a way to remove this call without performance implications.
4026+
if not self.has_member(base_type, op_name):
4027+
return None
4028+
4029+
with self.msg.filter_errors() as w:
4030+
member = analyze_member_access(
4031+
name=op_name,
4032+
typ=base_type,
4033+
is_lvalue=False,
4034+
is_super=False,
4035+
is_operator=True,
4036+
original_type=base_type,
4037+
context=context,
4038+
chk=self.chk,
4039+
in_literal_context=self.is_literal_context(),
4040+
)
4041+
return None if w.has_new_errors() else member
4042+
4043+
4044+
def _lookup_definer(self, typ: Instance, attr_name: str) -> str | None:
4045+
"""Returns the name of the class that contains the actual definition of attr_name.
4046+
4047+
So if class A defines foo and class B subclasses A, running
4048+
'get_class_defined_in(B, "foo")` would return the full name of A.
4049+
4050+
However, if B were to override and redefine foo, that method call would
4051+
return the full name of B instead.
4052+
4053+
If the attr name is not present in the given class or its MRO, returns None.
4054+
"""
4055+
for cls in typ.type.mro:
4056+
if cls.names.get(attr_name):
4057+
return cls.fullname
4058+
return None
4059+
4060+
4061+
def _determine_operator_order(
4062+
self,
4063+
op_name: str,
4064+
rev_op_name: str,
4065+
left_type: Type,
4066+
right_type: Type,
4067+
left_expr: Expression,
4068+
right_expr: Expression,
4069+
left_op: Type | None,
4070+
right_op: Type | None,
4071+
) -> list[tuple[str, Type | None, Type, Expression]]:
4072+
"""Determine in which order Python will attempt to call the operator methods."""
4073+
4074+
# When we do "A() + A()", for example, Python will only call the __add__ method,
4075+
# never the __radd__ method. This is the case even if the __add__ method is missing
4076+
# and the __radd__ method is defined.
40424077
if op_name in operators.op_methods_that_shortcut and is_same_type(left_type, right_type):
4043-
# When we do "A() + A()", for example, Python will only call the __add__ method,
4044-
# never the __radd__ method.
4045-
#
4046-
# This is the case even if the __add__ method is completely missing and the __radd__
4047-
# method is defined.
4078+
return [(op_name, left_op, left_type, right_expr)]
40484079

4049-
variants_raw = [(op_name, left_op, left_type, right_expr)]
4050-
elif (
4080+
left_type = get_proper_type(left_type)
4081+
right_type = get_proper_type(right_type)
4082+
4083+
if (
40514084
is_subtype(right_type, left_type)
40524085
and isinstance(left_type, Instance)
40534086
and isinstance(right_type, Instance)
40544087
and not (
40554088
left_type.type.alt_promote is not None
40564089
and left_type.type.alt_promote.type is right_type.type
40574090
)
4058-
and lookup_definer(left_type, op_name) != lookup_definer(right_type, rev_op_name)
4091+
and self._lookup_definer(left_type, op_name)
4092+
!= self._lookup_definer(right_type, rev_op_name)
40594093
):
40604094
# When we do "A() + B()" where B is a subclass of A, we'll actually try calling
4061-
# B's __radd__ method first, but ONLY if B explicitly defines or overrides the
4062-
# __radd__ method.
4095+
# B's __radd__ method first, but ONLY if B explicitly defines or overrides it.
40634096
#
4064-
# This mechanism lets subclasses "refine" the expected outcome of the operation, even
4065-
# if they're located on the RHS.
4097+
# This mechanism lets subclasses "refine" the expected outcome of the operation,
4098+
# even if they're located on the RHS.
40664099
#
40674100
# As a special case, the alt_promote check makes sure that we don't use the
40684101
# __radd__ method of int if the LHS is a native int type.
4069-
4070-
variants_raw = [
4102+
return [
40714103
(rev_op_name, right_op, right_type, left_expr),
40724104
(op_name, left_op, left_type, right_expr),
40734105
]
4074-
else:
4075-
# In all other cases, we do the usual thing and call __add__ first and
4076-
# __radd__ second when doing "A() + B()".
40774106

4078-
variants_raw = [
4079-
(op_name, left_op, left_type, right_expr),
4080-
(rev_op_name, right_op, right_type, left_expr),
4081-
]
4107+
# In all other cases, we do the usual thing and call __add__ first and
4108+
# __radd__ second when doing "A() + B()".
4109+
return [
4110+
(op_name, left_op, left_type, right_expr),
4111+
(rev_op_name, right_op, right_type, left_expr),
4112+
]
40824113

4083-
# STEP 3:
4084-
# We now filter out all non-existent operators. The 'variants' list contains
4085-
# all operator methods that are actually present, in the order that Python
4086-
# attempts to invoke them.
4087-
4088-
variants = [(na, op, obj, arg) for (na, op, obj, arg) in variants_raw if op is not None]
40894114

4090-
# STEP 4:
4091-
# We now try invoking each one. If an operation succeeds, end early and return
4092-
# the corresponding result. Otherwise, return the result and errors associated
4093-
# with the first entry.
4115+
def _attempt_operator_applications(
4116+
self,
4117+
op_name: str,
4118+
variants: list[tuple[str, Type, Type, Expression]],
4119+
left_type: Type,
4120+
right_type: Type,
4121+
left_expr: Expression,
4122+
right_expr: Expression,
4123+
context: Context,
4124+
) -> tuple[Type, Type]:
4125+
"""Try applying the operator methods and handle possible fallbacks."""
40944126

40954127
errors = []
40964128
results = []
4129+
40974130
for name, method, obj, arg in variants:
40984131
with self.msg.filter_errors(save_filtered_errors=True) as local_errors:
4099-
result = self.check_method_call(name, obj, method, [arg], [ARG_POS], context)
4132+
result = self.check_method_call(op_name, obj, method, [arg], [ARG_POS], context)
4133+
41004134
if local_errors.has_new_errors():
41014135
errors.append(local_errors.filtered_errors())
41024136
results.append(result)
41034137
else:
4138+
obj = get_proper_type(obj)
41044139
if isinstance(obj, Instance) and isinstance(
41054140
defn := obj.type.get_method(name), OverloadedFuncDef
41064141
):
@@ -4113,6 +4148,9 @@ def lookup_definer(typ: Instance, attr_name: str) -> str | None:
41134148
self.chk.check_deprecated(item.func, context)
41144149
return result
41154150

4151+
left_type = get_proper_type(left_type)
4152+
right_type = get_proper_type(right_type)
4153+
41164154
# We finish invoking above operators and no early return happens. Therefore,
41174155
# we check if either the LHS or the RHS is Instance and fallbacks to Any,
41184156
# if so, we also return Any
@@ -4125,13 +4163,11 @@ def lookup_definer(typ: Instance, attr_name: str) -> str | None:
41254163
# STEP 4b:
41264164
# Sometimes, the variants list is empty. In that case, we fall-back to attempting to
41274165
# call the __op__ method (even though it's missing).
4128-
41294166
if not variants:
41304167
with self.msg.filter_errors(save_filtered_errors=True) as local_errors:
41314168
result = self.check_method_call_by_name(
41324169
op_name, left_type, [right_expr], [ARG_POS], context
41334170
)
4134-
41354171
if local_errors.has_new_errors():
41364172
errors.append(local_errors.filtered_errors())
41374173
results.append(result)
@@ -4146,13 +4182,13 @@ def lookup_definer(typ: Instance, attr_name: str) -> str | None:
41464182
# TODO: Remove this extra case
41474183
return result
41484184

4185+
# Return the result and emit the first error
41494186
self.msg.add_errors(errors[0])
41504187
if len(results) == 1:
41514188
return results[0]
41524189
else:
41534190
error_any = AnyType(TypeOfAny.from_error)
4154-
result = error_any, error_any
4155-
return result
4191+
return error_any, error_any
41564192

41574193
def check_op(
41584194
self,

0 commit comments

Comments
 (0)