Skip to content

Commit 1e9dc5c

Browse files
Merge pull request #391 from lshaw8317/fixChainingLazy
Enable chaining of lazy expressions for logical operators
2 parents 7324451 + e662cef commit 1e9dc5c

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/blosc2/lazyexpr.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,7 +1920,6 @@ def update_expr(self, new_op): # noqa: C901
19201920
if hasattr(value2, "_where_args"):
19211921
value2 = value2.compute()
19221922

1923-
self._dtype = infer_dtype(op, value1, value2)
19241923
if not isinstance(value1, LazyExpr) and not isinstance(value2, LazyExpr):
19251924
# We converted some of the operands to NDArray (where() handling above)
19261925
new_operands = {"o0": value1, "o1": value2}
@@ -2678,8 +2677,8 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
26782677
_shape = new_expr.shape
26792678
if isinstance(new_expr, blosc2.LazyExpr):
26802679
# Restore the original expression and operands
2681-
new_expr.expression = _expression
2682-
new_expr.expression_tosave = expression
2680+
new_expr.expression = f"({_expression})" # forcibly add parenthesis
2681+
new_expr.expression_tosave = new_expr.expression
26832682
new_expr.operands = _operands
26842683
new_expr.operands_tosave = operands
26852684
else:

tests/ndarray/test_lazyexpr.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,10 @@ def test_dtype_infer(dtype1, dtype2, scalar):
12201220
np.testing.assert_allclose(res, nres)
12211221
assert res.dtype == nres.dtype
12221222

1223+
# Check dtype not changed by expression creation (bug fix)
1224+
assert a.dtype == dtype1
1225+
assert b.dtype == dtype2
1226+
12231227

12241228
@pytest.mark.parametrize(
12251229
"cfunc", ["np.int8", "np.int16", "np.int32", "np.int64", "np.float32", "np.float64"]
@@ -1330,3 +1334,21 @@ def test_missing_operator():
13301334
# Clean up
13311335
blosc2.remove_urlpath("a.b2nd")
13321336
blosc2.remove_urlpath("expr.b2nd")
1337+
1338+
1339+
# Test the chaining of multiple lazy expressions
1340+
def test_chain_expressions():
1341+
N = 1_000
1342+
dtype = "float64"
1343+
a = blosc2.linspace(0, 1, N * N, dtype=dtype, shape=(N, N))
1344+
b = blosc2.linspace(1, 2, N * N, dtype=dtype, shape=(N, N))
1345+
c = blosc2.linspace(0, 1, N, dtype=dtype, shape=(N,))
1346+
1347+
le1 = a**3 + blosc2.sin(a**2)
1348+
le2 = le1 < c
1349+
le3 = le2 & (b < 0)
1350+
1351+
le1_ = blosc2.lazyexpr("a ** 3 + sin(a ** 2)", {"a": a})
1352+
le2_ = blosc2.lazyexpr("(le1 < c)", {"le1": le1_, "c": c})
1353+
le3_ = blosc2.lazyexpr("(le2 & (b < 0))", {"le2": le2_, "b": b})
1354+
assert (le3_[:] == le3[:]).all()

0 commit comments

Comments
 (0)