Skip to content

Commit 6eb3eaa

Browse files
committed
Add is_bool_or_bit_rprimitive
1 parent 503f5bd commit 6eb3eaa

File tree

5 files changed

+78
-31
lines changed

5 files changed

+78
-31
lines changed

mypyc/codegen/emit.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
RType,
2929
RUnion,
3030
int_rprimitive,
31-
is_bit_rprimitive,
32-
is_bool_rprimitive,
31+
is_bool_or_bit_rprimitive,
3332
is_bytes_rprimitive,
3433
is_dict_rprimitive,
3534
is_fixed_width_rtype,
@@ -615,8 +614,7 @@ def emit_cast(
615614
or is_range_rprimitive(typ)
616615
or is_float_rprimitive(typ)
617616
or is_int_rprimitive(typ)
618-
or is_bool_rprimitive(typ)
619-
or is_bit_rprimitive(typ)
617+
or is_bool_or_bit_rprimitive(typ)
620618
or is_fixed_width_rtype(typ)
621619
):
622620
if declare_dest:
@@ -638,7 +636,7 @@ def emit_cast(
638636
elif is_int_rprimitive(typ) or is_fixed_width_rtype(typ):
639637
# TODO: Range check for fixed-width types?
640638
prefix = "PyLong"
641-
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
639+
elif is_bool_or_bit_rprimitive(typ):
642640
prefix = "PyBool"
643641
else:
644642
assert False, f"unexpected primitive type: {typ}"
@@ -889,7 +887,7 @@ def emit_unbox(
889887
self.emit_line("else {")
890888
self.emit_line(failure)
891889
self.emit_line("}")
892-
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
890+
elif is_bool_or_bit_rprimitive(typ):
893891
# Whether we are borrowing or not makes no difference.
894892
if declare_dest:
895893
self.emit_line(f"char {dest};")
@@ -1015,7 +1013,7 @@ def emit_box(
10151013
if is_int_rprimitive(typ) or is_short_int_rprimitive(typ):
10161014
# Steal the existing reference if it exists.
10171015
self.emit_line(f"{declaration}{dest} = CPyTagged_StealAsObject({src});")
1018-
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
1016+
elif is_bool_or_bit_rprimitive(typ):
10191017
# N.B: bool is special cased to produce a borrowed value
10201018
# after boxing, so we don't need to increment the refcount
10211019
# when this comes directly from a Box op.

mypyc/ir/ops.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ class to enable the new behavior. Sometimes adding a new abstract
4242
cstring_rprimitive,
4343
float_rprimitive,
4444
int_rprimitive,
45-
is_bit_rprimitive,
46-
is_bool_rprimitive,
45+
is_bool_or_bit_rprimitive,
4746
is_int_rprimitive,
4847
is_none_rprimitive,
4948
is_pointer_rprimitive,
@@ -1089,11 +1088,7 @@ def __init__(self, src: Value, line: int = -1) -> None:
10891088
self.src = src
10901089
self.type = object_rprimitive
10911090
# When we box None and bool values, we produce a borrowed result
1092-
if (
1093-
is_none_rprimitive(self.src.type)
1094-
or is_bool_rprimitive(self.src.type)
1095-
or is_bit_rprimitive(self.src.type)
1096-
):
1091+
if is_none_rprimitive(self.src.type) or is_bool_or_bit_rprimitive(self.src.type):
10971092
self.is_borrowed = True
10981093

10991094
def sources(self) -> list[Value]:

mypyc/ir/rtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,10 @@ def is_bit_rprimitive(rtype: RType) -> bool:
582582
return isinstance(rtype, RPrimitive) and rtype.name == "bit"
583583

584584

585+
def is_bool_or_bit_rprimitive(rtype: RType) -> bool:
586+
return is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype)
587+
588+
585589
def is_object_rprimitive(rtype: RType) -> bool:
586590
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.object"
587591

mypyc/irbuild/ll_builder.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@
9393
dict_rprimitive,
9494
float_rprimitive,
9595
int_rprimitive,
96-
is_bit_rprimitive,
97-
is_bool_rprimitive,
96+
is_bool_or_bit_rprimitive,
9897
is_bytes_rprimitive,
9998
is_dict_rprimitive,
10099
is_fixed_width_rtype,
@@ -376,16 +375,12 @@ def coerce(
376375
):
377376
# Equivalent types
378377
return src
379-
elif (is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)) and is_tagged(
380-
target_type
381-
):
378+
elif is_bool_or_bit_rprimitive(src_type) and is_tagged(target_type):
382379
shifted = self.int_op(
383380
bool_rprimitive, src, Integer(1, bool_rprimitive), IntOp.LEFT_SHIFT
384381
)
385382
return self.add(Extend(shifted, target_type, signed=False))
386-
elif (
387-
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
388-
) and is_fixed_width_rtype(target_type):
383+
elif is_bool_or_bit_rprimitive(src_type) and is_fixed_width_rtype(target_type):
389384
return self.add(Extend(src, target_type, signed=False))
390385
elif isinstance(src, Integer) and is_float_rprimitive(target_type):
391386
if is_tagged(src_type):
@@ -1336,7 +1331,11 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13361331
return self.compare_strings(lreg, rreg, op, line)
13371332
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="):
13381333
return self.compare_bytes(lreg, rreg, op, line)
1339-
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
1334+
if (
1335+
is_bool_or_bit_rprimitive(ltype)
1336+
and is_bool_or_bit_rprimitive(rtype)
1337+
and op in BOOL_BINARY_OPS
1338+
):
13401339
if op in ComparisonOp.signed_ops:
13411340
return self.bool_comparison_op(lreg, rreg, op, line)
13421341
else:
@@ -1350,7 +1349,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13501349
op_id = int_op_to_id[op]
13511350
else:
13521351
op_id = IntOp.DIV
1353-
if is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
1352+
if is_bool_or_bit_rprimitive(rtype):
13541353
rreg = self.coerce(rreg, ltype, line)
13551354
rtype = ltype
13561355
if is_fixed_width_rtype(rtype) or is_tagged(rtype):
@@ -1362,7 +1361,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13621361
elif op in ComparisonOp.signed_ops:
13631362
if is_int_rprimitive(rtype):
13641363
rreg = self.coerce_int_to_fixed_width(rreg, ltype, line)
1365-
elif is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
1364+
elif is_bool_or_bit_rprimitive(rtype):
13661365
rreg = self.coerce(rreg, ltype, line)
13671366
op_id = ComparisonOp.signed_ops[op]
13681367
if is_fixed_width_rtype(rreg.type):
@@ -1382,13 +1381,13 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13821381
)
13831382
if is_tagged(ltype):
13841383
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
1385-
if is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
1384+
if is_bool_or_bit_rprimitive(ltype):
13861385
lreg = self.coerce(lreg, rtype, line)
13871386
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
13881387
elif op in ComparisonOp.signed_ops:
13891388
if is_int_rprimitive(ltype):
13901389
lreg = self.coerce_int_to_fixed_width(lreg, rtype, line)
1391-
elif is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
1390+
elif is_bool_or_bit_rprimitive(ltype):
13921391
lreg = self.coerce(lreg, rtype, line)
13931392
op_id = ComparisonOp.signed_ops[op]
13941393
if isinstance(lreg, Integer):
@@ -1534,7 +1533,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
15341533
compare = self.binary_op(lhs_item, rhs_item, op, line)
15351534
# Cast to bool if necessary since most types uses comparison returning a object type
15361535
# See generic_ops.py for more information
1537-
if not (is_bool_rprimitive(compare.type) or is_bit_rprimitive(compare.type)):
1536+
if not is_bool_or_bit_rprimitive(compare.type):
15381537
compare = self.primitive_op(bool_op, [compare], line)
15391538
if i < len(lhs.type.types) - 1:
15401539
branch = Branch(compare, early_stop, check_blocks[i + 1], Branch.BOOL)
@@ -1553,7 +1552,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
15531552

15541553
def translate_instance_contains(self, inst: Value, item: Value, op: str, line: int) -> Value:
15551554
res = self.gen_method_call(inst, "__contains__", [item], None, line)
1556-
if not is_bool_rprimitive(res.type):
1555+
if not is_bool_or_bit_rprimitive(res.type):
15571556
res = self.primitive_op(bool_op, [res], line)
15581557
if op == "not in":
15591558
res = self.bool_bitwise_op(res, Integer(1, rtype=bool_rprimitive), "^", line)
@@ -1580,7 +1579,7 @@ def unary_not(self, value: Value, line: int) -> Value:
15801579

15811580
def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
15821581
typ = value.type
1583-
if is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
1582+
if is_bool_or_bit_rprimitive(typ):
15841583
if expr_op == "not":
15851584
return self.unary_not(value, line)
15861585
if expr_op == "+":
@@ -1738,7 +1737,7 @@ def bool_value(self, value: Value) -> Value:
17381737
17391738
The result type can be bit_rprimitive or bool_rprimitive.
17401739
"""
1741-
if is_bool_rprimitive(value.type) or is_bit_rprimitive(value.type):
1740+
if is_bool_or_bit_rprimitive(value.type):
17421741
result = value
17431742
elif is_runtime_subtype(value.type, int_rprimitive):
17441743
zero = Integer(0, short_int_rprimitive)

mypyc/test-data/irbuild-bool.test

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,54 @@ L0:
422422
r1 = extend r0: builtins.bool to builtins.int
423423
x = r1
424424
return x
425+
426+
[case testBitToBoolPromotion]
427+
def bitand(x: float, y: float, z: float) -> bool:
428+
b = (x == y) & (x == z)
429+
return b
430+
def bitor(x: float, y: float, z: float) -> bool:
431+
b = (x == y) | (x == z)
432+
return b
433+
def bitxor(x: float, y: float, z: float) -> bool:
434+
b = (x == y) ^ (x == z)
435+
return b
436+
def invert(x: float, y: float) -> bool:
437+
return not(x == y)
438+
[out]
439+
def bitand(x, y, z):
440+
x, y, z :: float
441+
r0, r1 :: bit
442+
r2, b :: bool
443+
L0:
444+
r0 = x == y
445+
r1 = x == z
446+
r2 = r0 & r1
447+
b = r2
448+
return b
449+
def bitor(x, y, z):
450+
x, y, z :: float
451+
r0, r1 :: bit
452+
r2, b :: bool
453+
L0:
454+
r0 = x == y
455+
r1 = x == z
456+
r2 = r0 | r1
457+
b = r2
458+
return b
459+
def bitxor(x, y, z):
460+
x, y, z :: float
461+
r0, r1 :: bit
462+
r2, b :: bool
463+
L0:
464+
r0 = x == y
465+
r1 = x == z
466+
r2 = r0 ^ r1
467+
b = r2
468+
return b
469+
def invert(x, y):
470+
x, y :: float
471+
r0, r1 :: bit
472+
L0:
473+
r0 = x == y
474+
r1 = r0 ^ 1
475+
return r1

0 commit comments

Comments
 (0)