9393 dict_rprimitive ,
9494 float_rprimitive ,
9595 int_rprimitive ,
96- is_bit_rprimitive ,
97- is_bool_rprimitive ,
96+ is_bool_or_bit_rprimitive ,
9897 is_bytes_rprimitive ,
9998 is_dict_rprimitive ,
10099 is_fixed_width_rtype ,
175174 unary_ops ,
176175)
177176from mypyc .primitives .set_ops import new_set_op
178- from mypyc .primitives .str_ops import str_check_if_true , str_ssize_t_size_op , unicode_compare
177+ from mypyc .primitives .str_ops import (
178+ str_check_if_true ,
179+ str_eq ,
180+ str_ssize_t_size_op ,
181+ unicode_compare ,
182+ )
179183from mypyc .primitives .tuple_ops import list_tuple_op , new_tuple_op , new_tuple_with_length_op
180184from mypyc .rt_subtype import is_runtime_subtype
181185from mypyc .sametype import is_same_type
@@ -376,16 +380,12 @@ def coerce(
376380 ):
377381 # Equivalent types
378382 return src
379- elif (is_bool_rprimitive (src_type ) or is_bit_rprimitive (src_type )) and is_tagged (
380- target_type
381- ):
383+ elif is_bool_or_bit_rprimitive (src_type ) and is_tagged (target_type ):
382384 shifted = self .int_op (
383385 bool_rprimitive , src , Integer (1 , bool_rprimitive ), IntOp .LEFT_SHIFT
384386 )
385387 return self .add (Extend (shifted , target_type , signed = False ))
386- elif (
387- is_bool_rprimitive (src_type ) or is_bit_rprimitive (src_type )
388- ) and is_fixed_width_rtype (target_type ):
388+ elif is_bool_or_bit_rprimitive (src_type ) and is_fixed_width_rtype (target_type ):
389389 return self .add (Extend (src , target_type , signed = False ))
390390 elif isinstance (src , Integer ) and is_float_rprimitive (target_type ):
391391 if is_tagged (src_type ):
@@ -1336,7 +1336,11 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13361336 return self .compare_strings (lreg , rreg , op , line )
13371337 if is_bytes_rprimitive (ltype ) and is_bytes_rprimitive (rtype ) and op in ("==" , "!=" ):
13381338 return self .compare_bytes (lreg , rreg , op , line )
1339- if is_bool_rprimitive (ltype ) and is_bool_rprimitive (rtype ) and op in BOOL_BINARY_OPS :
1339+ if (
1340+ is_bool_or_bit_rprimitive (ltype )
1341+ and is_bool_or_bit_rprimitive (rtype )
1342+ and op in BOOL_BINARY_OPS
1343+ ):
13401344 if op in ComparisonOp .signed_ops :
13411345 return self .bool_comparison_op (lreg , rreg , op , line )
13421346 else :
@@ -1350,7 +1354,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13501354 op_id = int_op_to_id [op ]
13511355 else :
13521356 op_id = IntOp .DIV
1353- if is_bool_rprimitive ( rtype ) or is_bit_rprimitive (rtype ):
1357+ if is_bool_or_bit_rprimitive (rtype ):
13541358 rreg = self .coerce (rreg , ltype , line )
13551359 rtype = ltype
13561360 if is_fixed_width_rtype (rtype ) or is_tagged (rtype ):
@@ -1362,7 +1366,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13621366 elif op in ComparisonOp .signed_ops :
13631367 if is_int_rprimitive (rtype ):
13641368 rreg = self .coerce_int_to_fixed_width (rreg , ltype , line )
1365- elif is_bool_rprimitive ( rtype ) or is_bit_rprimitive (rtype ):
1369+ elif is_bool_or_bit_rprimitive (rtype ):
13661370 rreg = self .coerce (rreg , ltype , line )
13671371 op_id = ComparisonOp .signed_ops [op ]
13681372 if is_fixed_width_rtype (rreg .type ):
@@ -1382,13 +1386,13 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13821386 )
13831387 if is_tagged (ltype ):
13841388 return self .fixed_width_int_op (rtype , lreg , rreg , op_id , line )
1385- if is_bool_rprimitive ( ltype ) or is_bit_rprimitive (ltype ):
1389+ if is_bool_or_bit_rprimitive (ltype ):
13861390 lreg = self .coerce (lreg , rtype , line )
13871391 return self .fixed_width_int_op (rtype , lreg , rreg , op_id , line )
13881392 elif op in ComparisonOp .signed_ops :
13891393 if is_int_rprimitive (ltype ):
13901394 lreg = self .coerce_int_to_fixed_width (lreg , rtype , line )
1391- elif is_bool_rprimitive ( ltype ) or is_bit_rprimitive (ltype ):
1395+ elif is_bool_or_bit_rprimitive (ltype ):
13921396 lreg = self .coerce (lreg , rtype , line )
13931397 op_id = ComparisonOp .signed_ops [op ]
13941398 if isinstance (lreg , Integer ):
@@ -1471,6 +1475,11 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -
14711475
14721476 def compare_strings (self , lhs : Value , rhs : Value , op : str , line : int ) -> Value :
14731477 """Compare two strings"""
1478+ if op == "==" :
1479+ return self .primitive_op (str_eq , [lhs , rhs ], line )
1480+ elif op == "!=" :
1481+ eq = self .primitive_op (str_eq , [lhs , rhs ], line )
1482+ return self .add (ComparisonOp (eq , self .false (), ComparisonOp .EQ , line ))
14741483 compare_result = self .call_c (unicode_compare , [lhs , rhs ], line )
14751484 error_constant = Integer (- 1 , c_int_rprimitive , line )
14761485 compare_error_check = self .add (
@@ -1534,7 +1543,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
15341543 compare = self .binary_op (lhs_item , rhs_item , op , line )
15351544 # Cast to bool if necessary since most types uses comparison returning a object type
15361545 # See generic_ops.py for more information
1537- if not ( is_bool_rprimitive ( compare .type ) or is_bit_rprimitive ( compare . type ) ):
1546+ if not is_bool_or_bit_rprimitive ( compare .type ):
15381547 compare = self .primitive_op (bool_op , [compare ], line )
15391548 if i < len (lhs .type .types ) - 1 :
15401549 branch = Branch (compare , early_stop , check_blocks [i + 1 ], Branch .BOOL )
@@ -1553,7 +1562,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
15531562
15541563 def translate_instance_contains (self , inst : Value , item : Value , op : str , line : int ) -> Value :
15551564 res = self .gen_method_call (inst , "__contains__" , [item ], None , line )
1556- if not is_bool_rprimitive (res .type ):
1565+ if not is_bool_or_bit_rprimitive (res .type ):
15571566 res = self .primitive_op (bool_op , [res ], line )
15581567 if op == "not in" :
15591568 res = self .bool_bitwise_op (res , Integer (1 , rtype = bool_rprimitive ), "^" , line )
@@ -1580,7 +1589,7 @@ def unary_not(self, value: Value, line: int) -> Value:
15801589
15811590 def unary_op (self , value : Value , expr_op : str , line : int ) -> Value :
15821591 typ = value .type
1583- if is_bool_rprimitive ( typ ) or is_bit_rprimitive (typ ):
1592+ if is_bool_or_bit_rprimitive (typ ):
15841593 if expr_op == "not" :
15851594 return self .unary_not (value , line )
15861595 if expr_op == "+" :
@@ -1738,7 +1747,7 @@ def bool_value(self, value: Value) -> Value:
17381747
17391748 The result type can be bit_rprimitive or bool_rprimitive.
17401749 """
1741- if is_bool_rprimitive ( value . type ) or is_bit_rprimitive (value .type ):
1750+ if is_bool_or_bit_rprimitive (value .type ):
17421751 result = value
17431752 elif is_runtime_subtype (value .type , int_rprimitive ):
17441753 zero = Integer (0 , short_int_rprimitive )
0 commit comments