@@ -1395,12 +1395,6 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13951395 # Special case various ops
13961396 if op in ("is" , "is not" ):
13971397 return self .translate_is_op (lreg , rreg , op , line )
1398- # TODO: modify 'str' to use same interface as 'compare_bytes' as it avoids
1399- # call to PyErr_Occurred()
1400- if is_str_rprimitive (ltype ) and is_str_rprimitive (rtype ) and op in ("==" , "!=" ):
1401- return self .compare_strings (lreg , rreg , op , line )
1402- if is_bytes_rprimitive (ltype ) and is_bytes_rprimitive (rtype ) and op in ("==" , "!=" ):
1403- return self .compare_bytes (lreg , rreg , op , line )
14041398 if (
14051399 is_bool_or_bit_rprimitive (ltype )
14061400 and is_bool_or_bit_rprimitive (rtype )
@@ -1496,6 +1490,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
14961490 def dunder_op (self , lreg : Value , rreg : Value | None , op : str , line : int ) -> Value | None :
14971491 """
14981492 Dispatch a dunder method if applicable.
1493+
14991494 For example for `a + b` it will use `a.__add__(b)` which can lead to higher performance
15001495 due to the fact that the method could be already compiled and optimized instead of going
15011496 all the way through `PyNumber_Add(a, b)` python api (making a jump into the python DL).
@@ -1545,6 +1540,10 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
15451540 elif op == "!=" :
15461541 eq = self .primitive_op (str_eq , [lhs , rhs ], line )
15471542 return self .add (ComparisonOp (eq , self .false (), ComparisonOp .EQ , line ))
1543+
1544+ # TODO: modify 'str' to use same interface as 'compare_bytes' as it would avoid
1545+ # call to PyErr_Occurred() below
1546+
15481547 compare_result = self .call_c (unicode_compare , [lhs , rhs ], line )
15491548 error_constant = Integer (- 1 , c_int_rprimitive , line )
15501549 compare_error_check = self .add (
@@ -1648,55 +1647,75 @@ def bool_comparison_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Va
16481647 op_id = ComparisonOp .signed_ops [op ]
16491648 return self .comparison_op (lreg , rreg , op_id , line )
16501649
1651- def unary_not (self , value : Value , line : int ) -> Value :
1652- mask = Integer (1 , value .type , line )
1653- return self .int_op (value .type , value , mask , IntOp .XOR , line )
1650+ def _non_specialized_unary_op (self , value : Value , op : str , line : int ) -> Value :
1651+ if isinstance (value .type , RInstance ):
1652+ result = self .dunder_op (value , None , op , line )
1653+ if result is not None :
1654+ return result
1655+ primitive_ops_candidates = unary_ops .get (op , [])
1656+ target = self .matching_primitive_op (primitive_ops_candidates , [value ], line )
1657+ assert target , "Unsupported unary operation: %s" % op
1658+ return target
16541659
1655- def unary_op (self , value : Value , expr_op : str , line : int ) -> Value :
1660+ def unary_not (self , value : Value , line : int ) -> Value :
1661+ """Perform unary 'not'."""
16561662 typ = value .type
16571663 if is_bool_or_bit_rprimitive (typ ):
1658- if expr_op == "not" :
1659- return self .unary_not (value , line )
1660- if expr_op == "+" :
1661- return value
1662- if is_fixed_width_rtype (typ ):
1663- if expr_op == "-" :
1664- # Translate to '0 - x'
1665- return self .int_op (typ , Integer (0 , typ ), value , IntOp .SUB , line )
1666- elif expr_op == "~" :
1667- if typ .is_signed :
1668- # Translate to 'x ^ -1'
1669- return self .int_op (typ , value , Integer (- 1 , typ ), IntOp .XOR , line )
1670- else :
1671- # Translate to 'x ^ 0xff...'
1672- mask = (1 << (typ .size * 8 )) - 1
1673- return self .int_op (typ , value , Integer (mask , typ ), IntOp .XOR , line )
1674- elif expr_op == "+" :
1675- return value
1676- if is_float_rprimitive (typ ):
1677- if expr_op == "-" :
1678- return self .add (FloatNeg (value , line ))
1679- elif expr_op == "+" :
1680- return value
1664+ mask = Integer (1 , typ , line )
1665+ return self .int_op (typ , value , mask , IntOp .XOR , line )
1666+ return self ._non_specialized_unary_op (value , "not" , line )
16811667
1668+ def unary_minus (self , value : Value , line : int ) -> Value :
1669+ """Perform unary '-'."""
1670+ typ = value .type
16821671 if isinstance (value , Integer ):
16831672 # TODO: Overflow? Unsigned?
1684- num = value .value
1685- if is_short_int_rprimitive (typ ):
1686- num >>= 1
1687- return Integer (- num , typ , value .line )
1688- if is_tagged (typ ) and expr_op == "+" :
1673+ return Integer (- value .numeric_value (), typ , line )
1674+ elif isinstance (value , Float ):
1675+ return Float (- value .value , line )
1676+ elif is_fixed_width_rtype (typ ):
1677+ # Translate to '0 - x'
1678+ return self .int_op (typ , Integer (0 , typ ), value , IntOp .SUB , line )
1679+ elif is_float_rprimitive (typ ):
1680+ return self .add (FloatNeg (value , line ))
1681+ return self ._non_specialized_unary_op (value , "-" , line )
1682+
1683+ def unary_plus (self , value : Value , line : int ) -> Value :
1684+ """Perform unary '+'."""
1685+ typ = value .type
1686+ if (
1687+ is_tagged (typ )
1688+ or is_float_rprimitive (typ )
1689+ or is_bool_or_bit_rprimitive (typ )
1690+ or is_fixed_width_rtype (typ )
1691+ ):
16891692 return value
1690- if isinstance (value , Float ):
1691- return Float (- value .value , value .line )
1692- if isinstance (typ , RInstance ):
1693- result = self .dunder_op (value , None , expr_op , line )
1694- if result is not None :
1695- return result
1696- primitive_ops_candidates = unary_ops .get (expr_op , [])
1697- target = self .matching_primitive_op (primitive_ops_candidates , [value ], line )
1698- assert target , "Unsupported unary operation: %s" % expr_op
1699- return target
1693+ return self ._non_specialized_unary_op (value , "+" , line )
1694+
1695+ def unary_invert (self , value : Value , line : int ) -> Value :
1696+ """Perform unary '~'."""
1697+ typ = value .type
1698+ if is_fixed_width_rtype (typ ):
1699+ if typ .is_signed :
1700+ # Translate to 'x ^ -1'
1701+ return self .int_op (typ , value , Integer (- 1 , typ ), IntOp .XOR , line )
1702+ else :
1703+ # Translate to 'x ^ 0xff...'
1704+ mask = (1 << (typ .size * 8 )) - 1
1705+ return self .int_op (typ , value , Integer (mask , typ ), IntOp .XOR , line )
1706+ return self ._non_specialized_unary_op (value , "~" , line )
1707+
1708+ def unary_op (self , value : Value , op : str , line : int ) -> Value :
1709+ """Perform a unary operation."""
1710+ if op == "not" :
1711+ return self .unary_not (value , line )
1712+ elif op == "-" :
1713+ return self .unary_minus (value , line )
1714+ elif op == "+" :
1715+ return self .unary_plus (value , line )
1716+ elif op == "~" :
1717+ return self .unary_invert (value , line )
1718+ raise RuntimeError ("Unsupported unary operation: %s" % op )
17001719
17011720 def make_dict (self , key_value_pairs : Sequence [DictEntry ], line : int ) -> Value :
17021721 result : Value | None = None
@@ -2480,13 +2499,21 @@ def translate_special_method_call(
24802499 return primitive_op
24812500
24822501 def translate_eq_cmp (self , lreg : Value , rreg : Value , expr_op : str , line : int ) -> Value | None :
2483- """Add a equality comparison operation.
2502+ """Add an equality comparison operation.
2503+
2504+ Note that this doesn't cover all possible types.
24842505
24852506 Args:
24862507 expr_op: either '==' or '!='
24872508 """
24882509 ltype = lreg .type
24892510 rtype = rreg .type
2511+
2512+ if is_str_rprimitive (ltype ) and is_str_rprimitive (rtype ):
2513+ return self .compare_strings (lreg , rreg , expr_op , line )
2514+ if is_bytes_rprimitive (ltype ) and is_bytes_rprimitive (rtype ):
2515+ return self .compare_bytes (lreg , rreg , expr_op , line )
2516+
24902517 if not (isinstance (ltype , RInstance ) and ltype == rtype ):
24912518 return None
24922519
0 commit comments