Skip to content

Commit 7eabaf0

Browse files
authored
Merge pull request #419 from Blosc/fixConditions
Make behaviour of compute consistent for slicing
2 parents 173c438 + 3db4ddb commit 7eabaf0

File tree

3 files changed

+116
-28
lines changed

3 files changed

+116
-28
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Imports
2+
3+
import numpy as np
4+
5+
import blosc2
6+
7+
# Build a small on-disk structured array with three fields:
#   A = -x + 1 (int32), B = x - 2 (float32), C = 0.1 * x (float64), for x in [0, N).
# NOTE: the file name says "1M" but only N = 1000 rows are written.
N = 1000
it = ((-x + 1, x - 2, 0.1 * x) for x in range(N))
sa = blosc2.fromiter(
    it, dtype=[("A", "i4"), ("B", "f4"), ("C", "f8")], shape=(N,), urlpath="sa-1M.b2nd", mode="w"
)
# Indexing with a string condition; the assertions below treat the result as
# "the rows of sa for which A < B", evaluated lazily.
expr = sa["(A < B)"]
# In-memory NumPy copies used to build the reference results.
A = sa["A"][:]
B = sa["B"][:]
C = sa["C"][:]  # not used below
temp = sa[:]
# Reference boolean mask. A < B holds exactly when x >= 2, so rows 0 and 1 are False.
indices = A < B
# argmax on a boolean array gives the index of the first True (2 here); it is a NumPy integer.
idx = np.argmax(indices)

# One might think that expr[:10] gives the first 10 elements of the evaluated expression, but this is not the case.
# It actually evaluates the expression on the first 10 elements of the operands; since the condition is False
# for some of those elements, the result has fewer than 10 elements (8 here: rows 2..9).
sliced = expr.compute(slice(0, 10))
gotitem = expr[:10]
# compute(item) and __getitem__(item) must agree ...
np.testing.assert_array_equal(sliced[:], gotitem)
# ... and both must match "slice the operands first, then apply the mask" in plain NumPy.
np.testing.assert_array_equal(gotitem, temp[:10][indices[:10]])  # Equivalent syntax
# This makes sense, since the item can be understood as a request to compute on a portion of the operands.
# If a portion of the *result* is desired, compute the whole expression first and then slice it.

# Get the first element for which the condition is true.
sliced = expr.compute(idx)
gotitem = expr[idx]
# Arrays of one element
np.testing.assert_array_equal(sliced[()], gotitem)
np.testing.assert_array_equal(gotitem, temp[idx])

# Element 0 does not satisfy the condition (A = 1, B = -2).
# Should return void arrays here.
sliced = expr.compute(0)
gotitem = expr[0]
np.testing.assert_array_equal(sliced[()], gotitem)
np.testing.assert_array_equal(gotitem, temp[0])

src/blosc2/lazyexpr.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,9 @@ def compute(self, item: slice | list[slice] | None = None, **kwargs: Any) -> blo
278278
Parameters
279279
----------
280280
item: slice, list of slices, optional
281-
If not None, only the chunks that intersect with the slices
282-
in items will be evaluated.
281+
If provided, item is used to slice the operands *prior* to computation; not to retrieve specified slices of
282+
the evaluated result. This difference between slicing operands and slicing the final expression
283+
is important when reductions or a where clause are used in the expression.
283284
284285
kwargs: Any, optional
285286
Keyword arguments that are supported by the :func:`empty` constructor.
@@ -328,7 +329,9 @@ def __getitem__(self, item: int | slice | Sequence[slice]) -> blosc2.NDArray:
328329
Parameters
329330
----------
330331
item: int, slice or sequence of slices
331-
The slice(s) to be retrieved. Note that step parameter is not yet honored.
332+
The slice(s) applied to the operands *prior* to computation; they are not used to retrieve slices of
333+
the evaluated result. This difference between slicing operands and slicing the final expression
334+
is important when reductions or a where clause are used in the expression.
332335
333336
Returns
334337
-------
@@ -1378,7 +1381,8 @@ def slices_eval( # noqa: C901
13781381
for i, (c, s) in enumerate(zip(coords, chunks, strict=True))
13791382
)
13801383
# Check whether current slice_ intersects with _slice
1381-
if _slice is not None and _slice != ():
1384+
checker = _slice.item() if hasattr(_slice, "item") else _slice # can't use != when _slice is np.int
1385+
if checker is not None and checker != ():
13821386
# Ensure that _slice is of type slice
13831387
key = ndindex.ndindex(_slice).expand(shape).raw
13841388
_slice = tuple(k if isinstance(k, slice) else slice(k, k + 1, None) for k in key)
@@ -1508,19 +1512,7 @@ def slices_eval( # noqa: C901
15081512
else:
15091513
raise ValueError("The where condition must be a tuple with one or two elements")
15101514

1511-
if orig_slice is not None:
1512-
if isinstance(out, np.ndarray):
1513-
out = out[orig_slice]
1514-
if _order is not None:
1515-
indices_ = indices_[orig_slice]
1516-
elif isinstance(out, blosc2.NDArray):
1517-
# It *seems* better to choose an automatic chunks and blocks for the output array
1518-
# out = out.slice(orig_slice, chunks=out.chunks, blocks=out.blocks)
1519-
out = out.slice(orig_slice)
1520-
else:
1521-
raise ValueError("The output array is not a NumPy array or a NDArray")
1522-
1523-
if where is not None and len(where) < 2:
1515+
if where is not None and len(where) < 2: # Don't need to take orig_slice since filled up from 0 index
15241516
if _order is not None:
15251517
# argsort the result following _order
15261518
new_order = np.argsort(out[:lenout])
@@ -1532,6 +1524,19 @@ def slices_eval( # noqa: C901
15321524
else:
15331525
out.resize((lenout,))
15341526

1527+
else: # Need to take orig_slice since filled up array according to slice_ for each chunk
1528+
if orig_slice is not None:
1529+
if isinstance(out, np.ndarray):
1530+
out = out[orig_slice]
1531+
if _order is not None:
1532+
indices_ = indices_[orig_slice]
1533+
elif isinstance(out, blosc2.NDArray):
1534+
# It *seems* better to choose an automatic chunks and blocks for the output array
1535+
# out = out.slice(orig_slice, chunks=out.chunks, blocks=out.blocks)
1536+
out = out.slice(orig_slice)
1537+
else:
1538+
raise ValueError("The output array is not a NumPy array or a NDArray")
1539+
15351540
return out
15361541

15371542

@@ -1827,7 +1832,8 @@ def chunked_eval( # noqa: C901
18271832
operands: dict
18281833
A dictionary containing the operands for the expression.
18291834
item: int, slice or sequence of slices, optional
1830-
The slice(s) to be retrieved. Note that step parameter is not honored yet.
1835+
The slice(s) of the operands to be used in computation. Note that step parameter is not honored yet.
1836+
Item is used to slice the operands PRIOR to computation.
18311837
kwargs: Any, optional
18321838
Additional keyword arguments supported by the :func:`empty` constructor. In addition,
18331839
the following keyword arguments are supported:

tests/ndarray/test_lazyexpr_fields.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,17 +279,19 @@ def test_where_one_param(array_fixture):
279279
res = np.sort(res)
280280
nres = np.sort(nres)
281281
np.testing.assert_allclose(res[:], nres)
282+
282283
# Test with getitem
283284
sl = slice(100)
284285
res = expr.where(a1)[sl]
286+
nres = na1[sl][ne_evaluate("na1**2 + na2**2 > 2 * na1 * na2 + 1")[sl]]
285287
if len(a1.shape) == 1 or a1.chunks == a1.shape:
286288
# TODO: fix this, as it seems that is not working well for numexpr?
287289
if blosc2.IS_WASM:
288290
return
289-
np.testing.assert_allclose(res, nres[sl])
291+
np.testing.assert_allclose(res, nres)
290292
else:
291293
# In this case, we cannot compare results, only the length
292-
assert len(res) == len(nres[sl])
294+
assert len(res) == len(nres)
293295

294296

295297
# Test where indirectly via a condition in getitem in a NDArray
@@ -330,25 +332,26 @@ def test_where_getitem(array_fixture):
330332
# Test with partial slice
331333
sl = slice(100)
332334
res = sa1[a1**2 + a2**2 > 2 * a1 * a2 + 1][sl]
335+
nres = nsa1[sl][ne_evaluate("na1**2 + na2**2 > 2 * na1 * na2 + 1")[sl]]
333336
if len(a1.shape) == 1 or a1.chunks == a1.shape:
334337
# TODO: fix this, as it seems that is not working well for numexpr?
335338
if blosc2.IS_WASM:
336339
return
337-
np.testing.assert_allclose(res["a"], nres[sl]["a"])
338-
np.testing.assert_allclose(res["b"], nres[sl]["b"])
340+
np.testing.assert_allclose(res["a"], nres["a"])
341+
np.testing.assert_allclose(res["b"], nres["b"])
339342
else:
340343
# In this case, we cannot compare results, only the length
341-
assert len(res["a"]) == len(nres[sl]["a"])
342-
assert len(res["b"]) == len(nres[sl]["b"])
344+
assert len(res["a"]) == len(nres["a"])
345+
assert len(res["b"]) == len(nres["b"])
343346
# string version
344347
res = sa1["a**2 + b**2 > 2 * a * b + 1"][sl]
345348
if len(a1.shape) == 1 or a1.chunks == a1.shape:
346-
np.testing.assert_allclose(res["a"], nres[sl]["a"])
347-
np.testing.assert_allclose(res["b"], nres[sl]["b"])
349+
np.testing.assert_allclose(res["a"], nres["a"])
350+
np.testing.assert_allclose(res["b"], nres["b"])
348351
else:
349352
# We cannot compare the results here, other than the length
350-
assert len(res["a"]) == len(nres[sl]["a"])
351-
assert len(res["b"]) == len(nres[sl]["b"])
353+
assert len(res["a"]) == len(nres["a"])
354+
assert len(res["b"]) == len(nres["b"])
352355

353356

354357
# Test where indirectly via a condition in getitem in a NDField
@@ -631,3 +634,40 @@ def test_col_reduction(reduce_op):
631634
ns = nreduc(nC[nC > 0])
632635
np.testing.assert_allclose(s, ns)
633636
np.testing.assert_allclose(s2, ns)
637+
638+
639+
def test_fields_indexing():
    """Check that ``compute(item)`` and ``__getitem__(item)`` agree for a field condition.

    The item is applied to the *operands* before the condition is evaluated, so
    the reference result in NumPy terms is ``temp[item][indices[item]]`` rather
    than ``temp[indices][item]``. Slice, NumPy-integer and plain-integer items
    are covered.
    """
    N = 1000
    urlpath = "sa-1M.b2nd"
    # Fields: A = -x + 1 (int32), B = x - 2 (float32), C = 0.1 * x (float64).
    it = ((-x + 1, x - 2, 0.1 * x) for x in range(N))
    sa = blosc2.fromiter(
        it, dtype=[("A", "i4"), ("B", "f4"), ("C", "f8")], shape=(N,), urlpath=urlpath, mode="w"
    )
    try:
        expr = sa["(A < B)"]
        A = sa["A"][:]
        B = sa["B"][:]
        temp = sa[:]
        # A < B holds exactly when x >= 2, so rows 0 and 1 are False.
        indices = A < B
        # Index of the first True (2 here). This is a NumPy integer, not a Python int,
        # so it also exercises non-builtin integer items.
        idx = np.argmax(indices)

        # Returns fewer than 10 elements in general (8 here: rows 2..9).
        sliced = expr.compute(slice(0, 10))
        gotitem = expr[:10]
        np.testing.assert_array_equal(sliced[:], gotitem)
        np.testing.assert_array_equal(gotitem, temp[:10][indices[:10]])
        # This makes sense, since the item can be understood as a request to compute on a portion of
        # the operands. If a portion of the result is desired, compute the whole expression and then
        # slice it. For a general slice it is quite difficult to simply stop when the desired slice
        # has been obtained, or to try to optimise the chunk computation order.

        # Get the first element for which the condition is true.
        sliced = expr.compute(idx)
        gotitem = expr[idx]
        np.testing.assert_array_equal(sliced[()], gotitem)
        np.testing.assert_array_equal(gotitem, temp[idx])

        # Element 0 does not satisfy the condition (A = 1, B = -2).
        # Should return void arrays here.
        sliced = expr.compute(0)
        gotitem = expr[0]
        np.testing.assert_array_equal(sliced[()], gotitem)
        np.testing.assert_array_equal(gotitem, temp[0])
    finally:
        # Do not leave the on-disk array behind in the working directory.
        blosc2.remove_urlpath(urlpath)

0 commit comments

Comments
 (0)