Skip to content

Commit e662cef

Browse files
author
Luke Shaw
committed
Enable chaining of lazy expressions for logical operators
1 parent 982498e commit e662cef

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/blosc2/lazyexpr.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,7 +1919,6 @@ def update_expr(self, new_op): # noqa: C901
19191919
if hasattr(value2, "_where_args"):
19201920
value2 = value2.compute()
19211921

1922-
self._dtype = infer_dtype(op, value1, value2)
19231922
if not isinstance(value1, LazyExpr) and not isinstance(value2, LazyExpr):
19241923
# We converted some of the operands to NDArray (where() handling above)
19251924
new_operands = {"o0": value1, "o1": value2}
@@ -2677,8 +2676,8 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
26772676
_shape = new_expr.shape
26782677
if isinstance(new_expr, blosc2.LazyExpr):
26792678
# Restore the original expression and operands
2680-
new_expr.expression = _expression
2681-
new_expr.expression_tosave = expression
2679+
new_expr.expression = f"({_expression})" # forcibly add parenthesis
2680+
new_expr.expression_tosave = new_expr.expression
26822681
new_expr.operands = _operands
26832682
new_expr.operands_tosave = operands
26842683
else:

tests/ndarray/test_lazyexpr.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,10 @@ def test_dtype_infer(dtype1, dtype2, scalar):
12201220
np.testing.assert_allclose(res, nres)
12211221
assert res.dtype == nres.dtype
12221222

1223+
# Check dtype not changed by expression creation (bug fix)
1224+
assert a.dtype == dtype1
1225+
assert b.dtype == dtype2
1226+
12231227

12241228
@pytest.mark.parametrize(
12251229
"cfunc", ["np.int8", "np.int16", "np.int32", "np.int64", "np.float32", "np.float64"]
@@ -1330,3 +1334,21 @@ def test_missing_operator():
13301334
# Clean up
13311335
blosc2.remove_urlpath("a.b2nd")
13321336
blosc2.remove_urlpath("expr.b2nd")
1337+
1338+
1339+
# Test the chaining of multiple lazy expressions
1340+
def test_chain_expressions():
1341+
N = 1_000
1342+
dtype = "float64"
1343+
a = blosc2.linspace(0, 1, N * N, dtype=dtype, shape=(N, N))
1344+
b = blosc2.linspace(1, 2, N * N, dtype=dtype, shape=(N, N))
1345+
c = blosc2.linspace(0, 1, N, dtype=dtype, shape=(N,))
1346+
1347+
le1 = a**3 + blosc2.sin(a**2)
1348+
le2 = le1 < c
1349+
le3 = le2 & (b < 0)
1350+
1351+
le1_ = blosc2.lazyexpr("a ** 3 + sin(a ** 2)", {"a": a})
1352+
le2_ = blosc2.lazyexpr("(le1 < c)", {"le1": le1_, "c": c})
1353+
le3_ = blosc2.lazyexpr("(le2 & (b < 0))", {"le2": le2_, "b": b})
1354+
assert (le3_[:] == le3[:]).all()

0 commit comments

Comments
 (0)