Skip to content

Commit d9c77cb

Browse files
authored
[mypyc] Refactor IR build of equality and unary operators (#19756)
Make the code cleaner. This will also make it easier to implement additional optimizations. This probably fixes a few bugs as well.
1 parent 6a88c21 commit d9c77cb

File tree

1 file changed

+76
-49
lines changed

1 file changed

+76
-49
lines changed

mypyc/irbuild/ll_builder.py

Lines changed: 76 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,12 +1395,6 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13951395
# Special case various ops
13961396
if op in ("is", "is not"):
13971397
return self.translate_is_op(lreg, rreg, op, line)
1398-
# TODO: modify 'str' to use same interface as 'compare_bytes' as it avoids
1399-
# call to PyErr_Occurred()
1400-
if is_str_rprimitive(ltype) and is_str_rprimitive(rtype) and op in ("==", "!="):
1401-
return self.compare_strings(lreg, rreg, op, line)
1402-
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="):
1403-
return self.compare_bytes(lreg, rreg, op, line)
14041398
if (
14051399
is_bool_or_bit_rprimitive(ltype)
14061400
and is_bool_or_bit_rprimitive(rtype)
@@ -1496,6 +1490,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
14961490
def dunder_op(self, lreg: Value, rreg: Value | None, op: str, line: int) -> Value | None:
14971491
"""
14981492
Dispatch a dunder method if applicable.
1493+
14991494
For example for `a + b` it will use `a.__add__(b)` which can lead to higher performance
15001495
due to the fact that the method could be already compiled and optimized instead of going
15011496
all the way through `PyNumber_Add(a, b)` python api (making a jump into the python DL).
@@ -1545,6 +1540,10 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
15451540
elif op == "!=":
15461541
eq = self.primitive_op(str_eq, [lhs, rhs], line)
15471542
return self.add(ComparisonOp(eq, self.false(), ComparisonOp.EQ, line))
1543+
1544+
# TODO: modify 'str' to use same interface as 'compare_bytes' as it would avoid
1545+
# call to PyErr_Occurred() below
1546+
15481547
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)
15491548
error_constant = Integer(-1, c_int_rprimitive, line)
15501549
compare_error_check = self.add(
@@ -1648,55 +1647,75 @@ def bool_comparison_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Va
16481647
op_id = ComparisonOp.signed_ops[op]
16491648
return self.comparison_op(lreg, rreg, op_id, line)
16501649

1651-
def unary_not(self, value: Value, line: int) -> Value:
1652-
mask = Integer(1, value.type, line)
1653-
return self.int_op(value.type, value, mask, IntOp.XOR, line)
1650+
def _non_specialized_unary_op(self, value: Value, op: str, line: int) -> Value:
1651+
if isinstance(value.type, RInstance):
1652+
result = self.dunder_op(value, None, op, line)
1653+
if result is not None:
1654+
return result
1655+
primitive_ops_candidates = unary_ops.get(op, [])
1656+
target = self.matching_primitive_op(primitive_ops_candidates, [value], line)
1657+
assert target, "Unsupported unary operation: %s" % op
1658+
return target
16541659

1655-
def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
1660+
def unary_not(self, value: Value, line: int) -> Value:
1661+
"""Perform unary 'not'."""
16561662
typ = value.type
16571663
if is_bool_or_bit_rprimitive(typ):
1658-
if expr_op == "not":
1659-
return self.unary_not(value, line)
1660-
if expr_op == "+":
1661-
return value
1662-
if is_fixed_width_rtype(typ):
1663-
if expr_op == "-":
1664-
# Translate to '0 - x'
1665-
return self.int_op(typ, Integer(0, typ), value, IntOp.SUB, line)
1666-
elif expr_op == "~":
1667-
if typ.is_signed:
1668-
# Translate to 'x ^ -1'
1669-
return self.int_op(typ, value, Integer(-1, typ), IntOp.XOR, line)
1670-
else:
1671-
# Translate to 'x ^ 0xff...'
1672-
mask = (1 << (typ.size * 8)) - 1
1673-
return self.int_op(typ, value, Integer(mask, typ), IntOp.XOR, line)
1674-
elif expr_op == "+":
1675-
return value
1676-
if is_float_rprimitive(typ):
1677-
if expr_op == "-":
1678-
return self.add(FloatNeg(value, line))
1679-
elif expr_op == "+":
1680-
return value
1664+
mask = Integer(1, typ, line)
1665+
return self.int_op(typ, value, mask, IntOp.XOR, line)
1666+
return self._non_specialized_unary_op(value, "not", line)
16811667

1668+
def unary_minus(self, value: Value, line: int) -> Value:
1669+
"""Perform unary '-'."""
1670+
typ = value.type
16821671
if isinstance(value, Integer):
16831672
# TODO: Overflow? Unsigned?
1684-
num = value.value
1685-
if is_short_int_rprimitive(typ):
1686-
num >>= 1
1687-
return Integer(-num, typ, value.line)
1688-
if is_tagged(typ) and expr_op == "+":
1673+
return Integer(-value.numeric_value(), typ, line)
1674+
elif isinstance(value, Float):
1675+
return Float(-value.value, line)
1676+
elif is_fixed_width_rtype(typ):
1677+
# Translate to '0 - x'
1678+
return self.int_op(typ, Integer(0, typ), value, IntOp.SUB, line)
1679+
elif is_float_rprimitive(typ):
1680+
return self.add(FloatNeg(value, line))
1681+
return self._non_specialized_unary_op(value, "-", line)
1682+
1683+
def unary_plus(self, value: Value, line: int) -> Value:
1684+
"""Perform unary '+'."""
1685+
typ = value.type
1686+
if (
1687+
is_tagged(typ)
1688+
or is_float_rprimitive(typ)
1689+
or is_bool_or_bit_rprimitive(typ)
1690+
or is_fixed_width_rtype(typ)
1691+
):
16891692
return value
1690-
if isinstance(value, Float):
1691-
return Float(-value.value, value.line)
1692-
if isinstance(typ, RInstance):
1693-
result = self.dunder_op(value, None, expr_op, line)
1694-
if result is not None:
1695-
return result
1696-
primitive_ops_candidates = unary_ops.get(expr_op, [])
1697-
target = self.matching_primitive_op(primitive_ops_candidates, [value], line)
1698-
assert target, "Unsupported unary operation: %s" % expr_op
1699-
return target
1693+
return self._non_specialized_unary_op(value, "+", line)
1694+
1695+
def unary_invert(self, value: Value, line: int) -> Value:
1696+
"""Perform unary '~'."""
1697+
typ = value.type
1698+
if is_fixed_width_rtype(typ):
1699+
if typ.is_signed:
1700+
# Translate to 'x ^ -1'
1701+
return self.int_op(typ, value, Integer(-1, typ), IntOp.XOR, line)
1702+
else:
1703+
# Translate to 'x ^ 0xff...'
1704+
mask = (1 << (typ.size * 8)) - 1
1705+
return self.int_op(typ, value, Integer(mask, typ), IntOp.XOR, line)
1706+
return self._non_specialized_unary_op(value, "~", line)
1707+
1708+
def unary_op(self, value: Value, op: str, line: int) -> Value:
1709+
"""Perform a unary operation."""
1710+
if op == "not":
1711+
return self.unary_not(value, line)
1712+
elif op == "-":
1713+
return self.unary_minus(value, line)
1714+
elif op == "+":
1715+
return self.unary_plus(value, line)
1716+
elif op == "~":
1717+
return self.unary_invert(value, line)
1718+
raise RuntimeError("Unsupported unary operation: %s" % op)
17001719

17011720
def make_dict(self, key_value_pairs: Sequence[DictEntry], line: int) -> Value:
17021721
result: Value | None = None
@@ -2480,13 +2499,21 @@ def translate_special_method_call(
24802499
return primitive_op
24812500

24822501
def translate_eq_cmp(self, lreg: Value, rreg: Value, expr_op: str, line: int) -> Value | None:
2483-
"""Add a equality comparison operation.
2502+
"""Add an equality comparison operation.
2503+
2504+
Note that this doesn't cover all possible types.
24842505
24852506
Args:
24862507
expr_op: either '==' or '!='
24872508
"""
24882509
ltype = lreg.type
24892510
rtype = rreg.type
2511+
2512+
if is_str_rprimitive(ltype) and is_str_rprimitive(rtype):
2513+
return self.compare_strings(lreg, rreg, expr_op, line)
2514+
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype):
2515+
return self.compare_bytes(lreg, rreg, expr_op, line)
2516+
24902517
if not (isinstance(ltype, RInstance) and ltype == rtype):
24912518
return None
24922519

0 commit comments

Comments
 (0)