Skip to content

Commit 1f0f33f

Browse files
authored
Merge pull request #411 from Blosc/fixChaining
Allow use of where for string expressions and integer operands
2 parents cfcd6c7 + 1cac8d4 commit 1f0f33f

File tree

3 files changed

+64
-8
lines changed

3 files changed

+64
-8
lines changed

src/blosc2/lazyexpr.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2091,6 +2091,8 @@ def dtype(self):
20912091

20922092
operands = {
20932093
key: np.ones(np.ones(len(value.shape), dtype=int), dtype=value.dtype)
2094+
if hasattr(value, "shape")
2095+
else value
20942096
for key, value in self.operands.items()
20952097
}
20962098

@@ -2784,6 +2786,20 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
27842786
expression_, operands_ = conserve_functions(
27852787
_expression, _operands, new_expr.operands | local_vars
27862788
)
2789+
# if new_expr has where_args, must have come from where(...) - or possibly where(where(..
2790+
# since 5*where, where + ... are evaluated eagerly
2791+
if hasattr(new_expr, "_where_args"):
2792+
st = expression_.find("where(") + len(
2793+
"where("
2794+
) # expr always begins where( - should have st = 6 always
2795+
finalexpr = ""
2796+
counter = 0
2797+
for char in expression_[st:]: # get rid of external where(...)
2798+
finalexpr += char
2799+
counter += 1 * (char == "(") - 1 * (char == ")")
2800+
if counter == 0 and char == ",":
2801+
break
2802+
expression_ = finalexpr[:-1] # remove trailing comma
27872803
new_expr.expression = f"({expression_})" # force parenthesis
27882804
new_expr.expression_tosave = expression
27892805
new_expr.operands = operands_

tests/ndarray/test_lazyexpr.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,14 +1364,13 @@ def test_chain_expressions():
13641364
le4_ = blosc2.lazyexpr("(le2 & le3)", {"le2": le2_, "le3": le3_})
13651365
assert (le4_[:] == le4[:]).all()
13661366

1367-
# TODO: Eventually this test should pass
1368-
# expr1 = blosc2.lazyexpr("arange(N) + b")
1369-
# expr2 = blosc2.lazyexpr("a * b + 1")
1370-
# expr = blosc2.lazyexpr("expr1 - expr2")
1371-
# expr_final = blosc2.lazyexpr("expr * expr")
1372-
# nres = (expr * expr)[:]
1373-
# res = expr_final.compute()
1374-
# np.testing.assert_allclose(res[:], nres)
1367+
expr1 = blosc2.lazyexpr("arange(N) + b")
1368+
expr2 = blosc2.lazyexpr("a * b + 1")
1369+
expr = blosc2.lazyexpr("expr1 - expr2")
1370+
expr_final = blosc2.lazyexpr("expr * expr")
1371+
nres = (expr * expr)[:]
1372+
res = expr_final.compute()
1373+
np.testing.assert_allclose(res[:], nres)
13751374

13761375

13771376
# Test the chaining of multiple persistent lazy expressions

tests/ndarray/test_lazyexpr_fields.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,47 @@ def test_where(array_fixture):
226226
np.testing.assert_allclose(res, nres[sl])
227227

228228

229+
# Test expressions with where() and string comps
230+
def test_lazy_where(array_fixture):
231+
sa1, sa2, nsa1, nsa2, a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture
232+
233+
# Test 1: where
234+
# Test with string expression
235+
expr = blosc2.lazyexpr("where((a1 ** 2 + a2 ** 2) > (2 * a1 * a2 + 1), 0, a1)")
236+
# Test with eval
237+
res = expr.compute()
238+
nres = ne_evaluate("where(na1**2 + na2**2 > 2 * na1 * na2 + 1, 0, na1)")
239+
np.testing.assert_allclose(res[:], nres)
240+
# Test with getitem
241+
sl = slice(100)
242+
res = expr[sl]
243+
np.testing.assert_allclose(res, nres[sl])
244+
245+
# Test 2: sum of wheres
246+
# Test with string expression
247+
expr = blosc2.lazyexpr("where(a1 < 0, 10, a1) + where(a2 < 0, 3, a2)")
248+
# Test with eval
249+
res = expr.compute()
250+
nres = ne_evaluate("where(na1 < 0, 10, na1) + where(na2 < 0, 3, na2)")
251+
np.testing.assert_allclose(res[:], nres)
252+
253+
# Test 3: nested wheres
254+
# Test with string expression
255+
expr = blosc2.lazyexpr("where(where(a2 < 0, 3, a2) > 3, 10, a1)")
256+
# Test with eval
257+
res = expr.compute()
258+
nres = ne_evaluate("where(where(na2 < 0, 3, na2) > 3, 10, na1)")
259+
np.testing.assert_allclose(res[:], nres)
260+
261+
# Test 4: multiplied wheres
262+
# Test with string expression
263+
expr = blosc2.lazyexpr("1 * where(a2 < 0, 3, a2)")
264+
# Test with eval
265+
res = expr.compute()
266+
nres = ne_evaluate("1 * where(na2 < 0, 3, na2)")
267+
np.testing.assert_allclose(res[:], nres)
268+
269+
229270
# Test where with one parameter
230271
def test_where_one_param(array_fixture):
231272
sa1, sa2, nsa1, nsa2, a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture

0 commit comments

Comments
 (0)