93
93
dict_rprimitive ,
94
94
float_rprimitive ,
95
95
int_rprimitive ,
96
- is_bit_rprimitive ,
97
- is_bool_rprimitive ,
96
+ is_bool_or_bit_rprimitive ,
98
97
is_bytes_rprimitive ,
99
98
is_dict_rprimitive ,
100
99
is_fixed_width_rtype ,
175
174
unary_ops ,
176
175
)
177
176
from mypyc .primitives .set_ops import new_set_op
178
- from mypyc .primitives .str_ops import str_check_if_true , str_ssize_t_size_op , unicode_compare
177
+ from mypyc .primitives .str_ops import (
178
+ str_check_if_true ,
179
+ str_eq ,
180
+ str_ssize_t_size_op ,
181
+ unicode_compare ,
182
+ )
179
183
from mypyc .primitives .tuple_ops import list_tuple_op , new_tuple_op , new_tuple_with_length_op
180
184
from mypyc .rt_subtype import is_runtime_subtype
181
185
from mypyc .sametype import is_same_type
@@ -376,16 +380,12 @@ def coerce(
376
380
):
377
381
# Equivalent types
378
382
return src
379
- elif (is_bool_rprimitive (src_type ) or is_bit_rprimitive (src_type )) and is_tagged (
380
- target_type
381
- ):
383
+ elif is_bool_or_bit_rprimitive (src_type ) and is_tagged (target_type ):
382
384
shifted = self .int_op (
383
385
bool_rprimitive , src , Integer (1 , bool_rprimitive ), IntOp .LEFT_SHIFT
384
386
)
385
387
return self .add (Extend (shifted , target_type , signed = False ))
386
- elif (
387
- is_bool_rprimitive (src_type ) or is_bit_rprimitive (src_type )
388
- ) and is_fixed_width_rtype (target_type ):
388
+ elif is_bool_or_bit_rprimitive (src_type ) and is_fixed_width_rtype (target_type ):
389
389
return self .add (Extend (src , target_type , signed = False ))
390
390
elif isinstance (src , Integer ) and is_float_rprimitive (target_type ):
391
391
if is_tagged (src_type ):
@@ -1337,7 +1337,11 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
1337
1337
return self .compare_strings (lreg , rreg , op , line )
1338
1338
if is_bytes_rprimitive (ltype ) and is_bytes_rprimitive (rtype ) and op in ("==" , "!=" ):
1339
1339
return self .compare_bytes (lreg , rreg , op , line )
1340
- if is_bool_rprimitive (ltype ) and is_bool_rprimitive (rtype ) and op in BOOL_BINARY_OPS :
1340
+ if (
1341
+ is_bool_or_bit_rprimitive (ltype )
1342
+ and is_bool_or_bit_rprimitive (rtype )
1343
+ and op in BOOL_BINARY_OPS
1344
+ ):
1341
1345
if op in ComparisonOp .signed_ops :
1342
1346
return self .bool_comparison_op (lreg , rreg , op , line )
1343
1347
else :
@@ -1351,7 +1355,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
1351
1355
op_id = int_op_to_id [op ]
1352
1356
else :
1353
1357
op_id = IntOp .DIV
1354
- if is_bool_rprimitive ( rtype ) or is_bit_rprimitive (rtype ):
1358
+ if is_bool_or_bit_rprimitive (rtype ):
1355
1359
rreg = self .coerce (rreg , ltype , line )
1356
1360
rtype = ltype
1357
1361
if is_fixed_width_rtype (rtype ) or is_tagged (rtype ):
@@ -1363,7 +1367,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
1363
1367
elif op in ComparisonOp .signed_ops :
1364
1368
if is_int_rprimitive (rtype ):
1365
1369
rreg = self .coerce_int_to_fixed_width (rreg , ltype , line )
1366
- elif is_bool_rprimitive ( rtype ) or is_bit_rprimitive (rtype ):
1370
+ elif is_bool_or_bit_rprimitive (rtype ):
1367
1371
rreg = self .coerce (rreg , ltype , line )
1368
1372
op_id = ComparisonOp .signed_ops [op ]
1369
1373
if is_fixed_width_rtype (rreg .type ):
@@ -1383,13 +1387,13 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
1383
1387
)
1384
1388
if is_tagged (ltype ):
1385
1389
return self .fixed_width_int_op (rtype , lreg , rreg , op_id , line )
1386
- if is_bool_rprimitive ( ltype ) or is_bit_rprimitive (ltype ):
1390
+ if is_bool_or_bit_rprimitive (ltype ):
1387
1391
lreg = self .coerce (lreg , rtype , line )
1388
1392
return self .fixed_width_int_op (rtype , lreg , rreg , op_id , line )
1389
1393
elif op in ComparisonOp .signed_ops :
1390
1394
if is_int_rprimitive (ltype ):
1391
1395
lreg = self .coerce_int_to_fixed_width (lreg , rtype , line )
1392
- elif is_bool_rprimitive ( ltype ) or is_bit_rprimitive (ltype ):
1396
+ elif is_bool_or_bit_rprimitive (ltype ):
1393
1397
lreg = self .coerce (lreg , rtype , line )
1394
1398
op_id = ComparisonOp .signed_ops [op ]
1395
1399
if isinstance (lreg , Integer ):
@@ -1472,6 +1476,11 @@ def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -
1472
1476
1473
1477
def compare_strings (self , lhs : Value , rhs : Value , op : str , line : int ) -> Value :
1474
1478
"""Compare two strings"""
1479
+ if op == "==" :
1480
+ return self .primitive_op (str_eq , [lhs , rhs ], line )
1481
+ elif op == "!=" :
1482
+ eq = self .primitive_op (str_eq , [lhs , rhs ], line )
1483
+ return self .add (ComparisonOp (eq , self .false (), ComparisonOp .EQ , line ))
1475
1484
compare_result = self .call_c (unicode_compare , [lhs , rhs ], line )
1476
1485
error_constant = Integer (- 1 , c_int_rprimitive , line )
1477
1486
compare_error_check = self .add (
@@ -1535,7 +1544,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
1535
1544
compare = self .binary_op (lhs_item , rhs_item , op , line )
1536
1545
# Cast to bool if necessary since most types uses comparison returning a object type
1537
1546
# See generic_ops.py for more information
1538
- if not ( is_bool_rprimitive ( compare .type ) or is_bit_rprimitive ( compare . type ) ):
1547
+ if not is_bool_or_bit_rprimitive ( compare .type ):
1539
1548
compare = self .primitive_op (bool_op , [compare ], line )
1540
1549
if i < len (lhs .type .types ) - 1 :
1541
1550
branch = Branch (compare , early_stop , check_blocks [i + 1 ], Branch .BOOL )
@@ -1554,7 +1563,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
1554
1563
1555
1564
def translate_instance_contains (self , inst : Value , item : Value , op : str , line : int ) -> Value :
1556
1565
res = self .gen_method_call (inst , "__contains__" , [item ], None , line )
1557
- if not is_bool_rprimitive (res .type ):
1566
+ if not is_bool_or_bit_rprimitive (res .type ):
1558
1567
res = self .primitive_op (bool_op , [res ], line )
1559
1568
if op == "not in" :
1560
1569
res = self .bool_bitwise_op (res , Integer (1 , rtype = bool_rprimitive ), "^" , line )
@@ -1581,7 +1590,7 @@ def unary_not(self, value: Value, line: int) -> Value:
1581
1590
1582
1591
def unary_op (self , value : Value , expr_op : str , line : int ) -> Value :
1583
1592
typ = value .type
1584
- if is_bool_rprimitive ( typ ) or is_bit_rprimitive (typ ):
1593
+ if is_bool_or_bit_rprimitive (typ ):
1585
1594
if expr_op == "not" :
1586
1595
return self .unary_not (value , line )
1587
1596
if expr_op == "+" :
@@ -1739,7 +1748,7 @@ def bool_value(self, value: Value) -> Value:
1739
1748
1740
1749
The result type can be bit_rprimitive or bool_rprimitive.
1741
1750
"""
1742
- if is_bool_rprimitive ( value . type ) or is_bit_rprimitive (value .type ):
1751
+ if is_bool_or_bit_rprimitive (value .type ):
1743
1752
result = value
1744
1753
elif is_runtime_subtype (value .type , int_rprimitive ):
1745
1754
zero = Integer (0 , short_int_rprimitive )
0 commit comments