Skip to content

Commit f547296

Browse files
committed
More tests on reductions on a single column
1 parent 1068de8 commit f547296

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

src/blosc2/lazyexpr.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1768,7 +1768,7 @@ def update_expr(self, new_op): # noqa: C901
17681768
new_operands = {}
17691769
# where() handling requires evaluating the expression prior to merge.
17701770
# This is different from reductions, where the expression is evaluated
1771-
# and returned an NumPy array (for usability convenience).
1771+
# and returned a NumPy array (for usability convenience).
17721772
# We do things like this to enable the fusion of operations like
17731773
# `a.where(0, 1).sum()`.
17741774
# Another possibility would have been to always evaluate where() and produce
@@ -1783,6 +1783,7 @@ def update_expr(self, new_op): # noqa: C901
17831783
# We converted some of the operands to NDArray (where() handling above)
17841784
new_operands = {"o0": value1, "o1": value2}
17851785
expression = f"(o0 {op} o1)"
1786+
return self._new_expr(expression, new_operands, guess=False, out=None, where=None)
17861787
elif isinstance(value1, LazyExpr) and isinstance(value2, LazyExpr):
17871788
# Expression fusion
17881789
# Fuse operands in expressions and detect duplicates
@@ -2113,6 +2114,14 @@ def prod(self, axis=None, dtype=None, keepdims=False, **kwargs):
21132114
return self.compute(_reduce_args=reduce_args, **kwargs)
21142115

21152116
def get_num_elements(self, axis, item):
2117+
if hasattr(self, "_where_args") and len(self._where_args) == 1:
2118+
# We have a where condition, so we need to count the number of elements
2119+
# fulfilling the condition
2120+
orig_where_args = self._where_args
2121+
self._where_args = {"_where_x": blosc2.ones(self.shape, dtype=np.int8)}
2122+
num_elements = self.sum(axis=axis, dtype=np.int64, item=item)
2123+
self._where_args = orig_where_args
2124+
return num_elements
21162125
if np.isscalar(axis):
21172126
axis = (axis,)
21182127
# Compute the number of elements in the array

tests/ndarray/test_lazyexpr_fields.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -502,22 +502,23 @@ def test_iter(shape, chunks, blocks):
502502
assert _i == shape[0] - 1
503503

504504

505-
def test_col_reduction():
505+
@pytest.mark.parametrize("reduce_op", ["sum", "mean", "min", "max", "std", "var"])
506+
def test_col_reduction(reduce_op):
506507
N = 1000
507508
rng = np.random.default_rng()
508509
it = ((-x + 1, x - 2, rng.normal()) for x in range(N))
509-
sa = blosc2.fromiter(
510-
it, dtype=[("A", "i4"), ("B", "f4"), ("C", "f8")], shape=(N,), urlpath="sa-1M.b2nd", mode="w"
511-
)
510+
sa = blosc2.fromiter(it, dtype=[("A", "i4"), ("B", "f4"), ("C", "f8")], shape=(N,), chunks=(N // 2,))
512511

513512
# The operations
513+
reduc = getattr(blosc2, reduce_op)
514514
C = sa.fields["C"]
515-
s = blosc2.sum(C[C > 0])
516-
s2 = blosc2.sum(C["C > 0"])
515+
s = reduc(C[C > 0])
516+
s2 = reduc(C["C > 0"]) # string version
517517

518518
# Check
519+
nreduc = getattr(np, reduce_op)
519520
nsa = sa[:]
520521
nC = nsa["C"]
521-
ns = np.sum(nC[nC > 0])
522+
ns = nreduc(nC[nC > 0])
522523
np.testing.assert_allclose(s, ns)
523524
np.testing.assert_allclose(s2, ns)

0 commit comments

Comments
 (0)