Skip to content

Commit f1487d4

Browse files
committed
Fix boolean add/mult
1 parent f114e9f commit f1487d4

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

numexpr/expressions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,9 @@ class OpNode(ExpressionNode):
531531
def __init__(self, opcode=None, args=None, kind=None):
532532
if (kind is None) and (args is not None):
533533
kind = commonKind(args)
534+
if kind=='bool': # handle bool*bool and bool+bool cases
535+
opcode = 'and' if opcode=='mul' else opcode
536+
opcode = 'or' if opcode=='add' else opcode
534537
ExpressionNode.__init__(self, value=opcode, kind=kind, children=args)
535538

536539

numexpr/tests/test_numexpr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,12 @@ def test_maximum_minimum(self):
490490
assert_array_equal(evaluate("maximum(x,y)"), maximum(x,y))
491491
assert_array_equal(evaluate("minimum(x,y)"), minimum(x,y))
492492

493+
def test_addmult_booleans(self):
494+
x = np.asarray([0, 1, 0, 0, 1], dtype=bool)
495+
y = x[::-1]
496+
assert_array_equal(evaluate("x * y"), x * y)
497+
assert_array_equal(evaluate("x + y"), x + y)
498+
493499
def test_sign_round(self):
494500
for dtype in [float, double, np.int32, np.int64, complex]:
495501
x = arange(10, dtype=dtype)

0 commit comments

Comments
 (0)