|
14 | 14 |
|
15 | 15 | from mypy.argmap import map_actuals_to_formals |
16 | 16 | from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind |
17 | | -from mypy.operators import op_methods |
| 17 | +from mypy.operators import op_methods, unary_op_methods |
18 | 18 | from mypy.types import AnyType, TypeOfAny |
19 | 19 | from mypyc.common import ( |
20 | 20 | BITMAP_BITS, |
|
167 | 167 | buf_init_item, |
168 | 168 | fast_isinstance_op, |
169 | 169 | none_object_op, |
| 170 | + not_implemented_op, |
170 | 171 | var_object_size, |
171 | 172 | ) |
172 | 173 | from mypyc.primitives.registry import ( |
@@ -1398,11 +1399,48 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: |
1398 | 1399 | if base_op in float_op_to_id: |
1399 | 1400 | return self.float_op(lreg, rreg, base_op, line) |
1400 | 1401 |
|
| 1402 | + dunder_op = self.dunder_op(lreg, rreg, op, line) |
| 1403 | + if dunder_op: |
| 1404 | + return dunder_op |
| 1405 | + |
1401 | 1406 | primitive_ops_candidates = binary_ops.get(op, []) |
1402 | 1407 | target = self.matching_primitive_op(primitive_ops_candidates, [lreg, rreg], line) |
1403 | 1408 | assert target, "Unsupported binary operation: %s" % op |
1404 | 1409 | return target |
1405 | 1410 |
|
| 1411 | + def dunder_op(self, lreg: Value, rreg: Value | None, op: str, line: int) -> Value | None: |
| 1412 | + """Dispatch a dunder method call if applicable. |
| 1413 | + |
| 1414 | + For example, `a + b` is lowered to a direct call to `a.__add__(b)`, which can be |
| 1415 | + faster because the method may already be compiled and specialized, instead of |
| 1416 | + going through the generic `PyNumber_Add(a, b)` C API (a jump into the CPython shared library). |
| 1417 | + """ |
| 1418 | + ltype = lreg.type |
| 1419 | + if not isinstance(ltype, RInstance): |
| 1420 | + return None |
| 1421 | + |
| 1422 | + method_name = op_methods.get(op) if rreg else unary_op_methods.get(op) |
| 1423 | + if method_name is None: |
| 1424 | + return None |
| 1425 | + |
| 1426 | + if not ltype.class_ir.has_method(method_name): |
| 1427 | + return None |
| 1428 | + |
| 1429 | + decl = ltype.class_ir.method_decl(method_name) |
| 1430 | + if not rreg and len(decl.sig.args) != 1: |
| 1431 | + return None |
| 1432 | + |
| 1433 | + if rreg and (len(decl.sig.args) != 2 or not is_subtype(rreg.type, decl.sig.args[1].type)): |
| 1434 | + return None |
| 1435 | + |
| 1436 | + if rreg and is_subtype(not_implemented_op.type, decl.sig.ret_type): |
| 1437 | + # If the method is able to return NotImplemented, we should not take the fast path; |
| 1438 | + # fall back to the Python API so the reflected method (e.g. __radd__) can be tried. |
| 1439 | + return None |
| 1440 | + |
| 1441 | + args = [rreg] if rreg else [] |
| 1442 | + return self.gen_method_call(lreg, method_name, args, decl.sig.ret_type, line) |
| 1443 | + |
1406 | 1444 | def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value: |
1407 | 1445 | """Check if a tagged integer is a short integer. |
1408 | 1446 |
|
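To illustrate the fast path added in the hunk above, here is a minimal sketch (not part of the diff; the `Vec` class, its methods, and the `add` function are hypothetical) of the kind of compiled code `dunder_op` targets: an operand whose type is a native `RInstance` class with a compatible `__add__` signature gets a direct method call, while a dunder whose declared return type admits `NotImplemented` keeps the generic Python API path.

```python
class Vec:
    def __init__(self, x: int) -> None:
        self.x = x

    # Precise argument and return types: with this change, `a + b` on two
    # native Vec instances compiles to a direct call to __add__ instead of
    # going through the generic PyNumber_Add() C API.
    def __add__(self, other: "Vec") -> "Vec":
        return Vec(self.x + other.x)

    # The declared return type (object) admits NotImplemented, so dunder_op()
    # declines the fast path and the operation stays on the Python API route,
    # which knows how to try the reflected method (__rmul__) next.
    def __mul__(self, other: "Vec") -> object:
        return NotImplemented


def add(a: Vec, b: Vec) -> Vec:
    return a + b  # direct a.__add__(b) call in compiled code
```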
@@ -1558,16 +1596,9 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value: |
1558 | 1596 | if isinstance(value, Float): |
1559 | 1597 | return Float(-value.value, value.line) |
1560 | 1598 | if isinstance(typ, RInstance): |
1561 | | - if expr_op == "-": |
1562 | | - method = "__neg__" |
1563 | | - elif expr_op == "+": |
1564 | | - method = "__pos__" |
1565 | | - elif expr_op == "~": |
1566 | | - method = "__invert__" |
1567 | | - else: |
1568 | | - method = "" |
1569 | | - if method and typ.class_ir.has_method(method): |
1570 | | - return self.gen_method_call(value, method, [], None, line) |
| 1599 | + result = self.dunder_op(value, None, expr_op, line) |
| 1600 | + if result is not None: |
| 1601 | + return result |
1571 | 1602 | call_c_ops_candidates = unary_ops.get(expr_op, []) |
1572 | 1603 | target = self.matching_call_c(call_c_ops_candidates, [value], line) |
1573 | 1604 | assert target, "Unsupported unary operation: %s" % expr_op |
|
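The second hunk reuses the same helper for unary operators. A brief sketch (hypothetical `Angle` class and `flip` function) of code that takes this path:

```python
class Angle:
    def __init__(self, deg: int) -> None:
        self.deg = deg

    # A unary dunder takes no arguments besides self, so dunder_op() checks
    # that the declared signature has exactly one argument and then emits a
    # direct call, replacing the hand-written __neg__/__pos__/__invert__
    # lookup previously inlined in unary_op().
    def __neg__(self) -> "Angle":
        return Angle(-self.deg)


def flip(a: Angle) -> Angle:
    return -a  # compiled as a direct a.__neg__() call
```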