Skip to content
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 76 additions & 49 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,12 +1395,6 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
# Special case various ops
if op in ("is", "is not"):
return self.translate_is_op(lreg, rreg, op, line)
# TODO: modify 'str' to use same interface as 'compare_bytes' as it avoids
# call to PyErr_Occurred()
if is_str_rprimitive(ltype) and is_str_rprimitive(rtype) and op in ("==", "!="):
return self.compare_strings(lreg, rreg, op, line)
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="):
return self.compare_bytes(lreg, rreg, op, line)
if (
is_bool_or_bit_rprimitive(ltype)
and is_bool_or_bit_rprimitive(rtype)
Expand Down Expand Up @@ -1496,6 +1490,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
def dunder_op(self, lreg: Value, rreg: Value | None, op: str, line: int) -> Value | None:
"""
Dispatch a dunder method if applicable.

For example for `a + b` it will use `a.__add__(b)` which can lead to higher performance
due to the fact that the method could be already compiled and optimized instead of going
all the way through `PyNumber_Add(a, b)` python api (making a jump into the python DL).
Expand Down Expand Up @@ -1545,6 +1540,10 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
elif op == "!=":
eq = self.primitive_op(str_eq, [lhs, rhs], line)
return self.add(ComparisonOp(eq, self.false(), ComparisonOp.EQ, line))

# TODO: modify 'str' to use same interface as 'compare_bytes' as it would avoid
# call to PyErr_Occurred() below

compare_result = self.call_c(unicode_compare, [lhs, rhs], line)
error_constant = Integer(-1, c_int_rprimitive, line)
compare_error_check = self.add(
Expand Down Expand Up @@ -1648,55 +1647,75 @@ def bool_comparison_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Va
op_id = ComparisonOp.signed_ops[op]
return self.comparison_op(lreg, rreg, op_id, line)

def unary_not(self, value: Value, line: int) -> Value:
mask = Integer(1, value.type, line)
return self.int_op(value.type, value, mask, IntOp.XOR, line)
def _non_specialized_unary_op(self, value: Value, op: str, line: int) -> Value:
if isinstance(value.type, RInstance):
result = self.dunder_op(value, None, op, line)
if result is not None:
return result
primitive_ops_candidates = unary_ops.get(op, [])
target = self.matching_primitive_op(primitive_ops_candidates, [value], line)
assert target, "Unsupported unary operation: %s" % op
return target

def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
def unary_not(self, value: Value, line: int) -> Value:
"""Perform unary 'not'."""
typ = value.type
if is_bool_or_bit_rprimitive(typ):
if expr_op == "not":
return self.unary_not(value, line)
if expr_op == "+":
return value
if is_fixed_width_rtype(typ):
if expr_op == "-":
# Translate to '0 - x'
return self.int_op(typ, Integer(0, typ), value, IntOp.SUB, line)
elif expr_op == "~":
if typ.is_signed:
# Translate to 'x ^ -1'
return self.int_op(typ, value, Integer(-1, typ), IntOp.XOR, line)
else:
# Translate to 'x ^ 0xff...'
mask = (1 << (typ.size * 8)) - 1
return self.int_op(typ, value, Integer(mask, typ), IntOp.XOR, line)
elif expr_op == "+":
return value
if is_float_rprimitive(typ):
if expr_op == "-":
return self.add(FloatNeg(value, line))
elif expr_op == "+":
return value
mask = Integer(1, value.type, line)
return self.int_op(value.type, value, mask, IntOp.XOR, line)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could use typ in place of value.type here as well.

return self._non_specialized_unary_op(value, "not", line)

def unary_minus(self, value: Value, line: int) -> Value:
"""Perform unary '-'."""
typ = value.type
if isinstance(value, Integer):
# TODO: Overflow? Unsigned?
num = value.value
if is_short_int_rprimitive(typ):
num >>= 1
return Integer(-num, typ, value.line)
if is_tagged(typ) and expr_op == "+":
return Integer(-value.numeric_value(), typ, line)
elif isinstance(value, Float):
return Float(-value.value, line)
elif is_fixed_width_rtype(typ):
# Translate to '0 - x'
return self.int_op(typ, Integer(0, typ), value, IntOp.SUB, line)
elif is_float_rprimitive(typ):
return self.add(FloatNeg(value, line))
return self._non_specialized_unary_op(value, "-", line)

def unary_plus(self, value: Value, line: int) -> Value:
"""Perform unary '+'."""
typ = value.type
if (
is_tagged(typ)
or is_float_rprimitive(typ)
or is_bool_or_bit_rprimitive(typ)
or is_fixed_width_rtype(typ)
):
return value
if isinstance(value, Float):
return Float(-value.value, value.line)
if isinstance(typ, RInstance):
result = self.dunder_op(value, None, expr_op, line)
if result is not None:
return result
primitive_ops_candidates = unary_ops.get(expr_op, [])
target = self.matching_primitive_op(primitive_ops_candidates, [value], line)
assert target, "Unsupported unary operation: %s" % expr_op
return target
return self._non_specialized_unary_op(value, "+", line)

def unary_invert(self, value: Value, line: int) -> Value:
"""Perform unary '~'."""
typ = value.type
if is_fixed_width_rtype(typ):
if typ.is_signed:
# Translate to 'x ^ -1'
return self.int_op(typ, value, Integer(-1, typ), IntOp.XOR, line)
else:
# Translate to 'x ^ 0xff...'
mask = (1 << (typ.size * 8)) - 1
return self.int_op(typ, value, Integer(mask, typ), IntOp.XOR, line)
return self._non_specialized_unary_op(value, "~", line)

def unary_op(self, value: Value, op: str, line: int) -> Value:
"""Perform a unary operation."""
if op == "not":
return self.unary_not(value, line)
elif op == "-":
return self.unary_minus(value, line)
elif op == "+":
return self.unary_plus(value, line)
elif op == "~":
return self.unary_invert(value, line)
raise RuntimeError("Unsupported unary operation: %s" % op)

def make_dict(self, key_value_pairs: Sequence[DictEntry], line: int) -> Value:
result: Value | None = None
Expand Down Expand Up @@ -2480,13 +2499,21 @@ def translate_special_method_call(
return primitive_op

def translate_eq_cmp(self, lreg: Value, rreg: Value, expr_op: str, line: int) -> Value | None:
"""Add a equality comparison operation.
"""Add an equality comparison operation.

Note that this doesn't cover all possible types.

Args:
expr_op: either '==' or '!='
"""
ltype = lreg.type
rtype = rreg.type

if is_str_rprimitive(ltype) and is_str_rprimitive(rtype):
return self.compare_strings(lreg, rreg, expr_op, line)
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype):
return self.compare_bytes(lreg, rreg, expr_op, line)

if not (isinstance(ltype, RInstance) and ltype == rtype):
return None

Expand Down
Loading