Skip to content

Commit db28770

Browse files
authored
Merge pull request #446 from Blosc/improve_slice_eval
Improve slice eval
2 parents 5e83dd7 + db2d38f commit db28770

File tree

8 files changed

+180
-57
lines changed

8 files changed

+180
-57
lines changed

.github/workflows/wasm.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
run: pip install cibuildwheel
4747

4848
- name: Build wheels
49-
# Testing is automaticall made by cibuildwheel
49+
# Testing is automatically made by cibuildwheel
5050
run: cibuildwheel --platform pyodide
5151

5252
- name: Upload wheels

bench/ndarray/slice-expr-step.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#######################################################################
2+
# Copyright (c) 2019-present, Blosc Development Team <[email protected]>
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under a BSD-style license (found in the
6+
# LICENSE file in the root directory of this source tree)
7+
#######################################################################
8+
9+
# Benchmark for computing a slice with non-unit steps of an expression in an ND array.
10+
11+
import blosc2
12+
import numpy as np
13+
import matplotlib.pyplot as plt
14+
from memory_profiler import profile, memory_usage
15+
16+
# Benchmark configuration: side length of the array and which slice to time.
N = 50_000
LARGE_SLICE = False  # True -> step of 2 (large result); False -> step of N // 4 (small result)
ndim = 2
shape = tuple(N for _ in range(ndim))

# Operand: an array of the given shape filled with a linear ramp, one value per element.
nelems = np.prod(shape)
a = blosc2.linspace(start=0, stop=nelems, num=nelems, dtype=np.float64, shape=shape)

# Only the first axis is sliced; the step depends on LARGE_SLICE.
step = 2 if LARGE_SLICE else N // 4
_slice = (slice(0, N, step),)

# Expression whose slicing is benchmarked below.
expr = 2 * a ** 2
23+
24+
@profile
def _slice_():
    """Evaluate ``expr.slice(_slice)``, print its ``schunk.cbytes`` size in MiB and return it."""
    sliced = expr.slice(_slice)
    size_mib = sliced.schunk.cbytes / 1024**2
    print(f'Result of slice occupies {size_mib:.2f} MiB')
    return sliced
29+
30+
@profile
def _gitem():
    """Evaluate ``expr[_slice]``, print its size (elements * itemsize) in MiB and return it."""
    picked = expr[_slice]
    nbytes = np.prod(picked.shape) * picked.itemsize
    print(f'Result of _getitem_ occupies {nbytes / 1024**2:.2f} MiB')
    return picked
35+
36+
# Sample the memory usage of each benchmark function and plot the traces
# back to back on a single time axis.
interval = 0.001  # sampling interval handed to memory_usage
offset = 0
for f in [_slice_, _gitem]:
    mem = memory_usage((f,), interval=interval)
    times = offset + interval * np.arange(len(mem))
    offset = times[-1]  # the next trace starts where this one ends
    plt.plot(times, mem)

plt.xlabel('Time (s)')
plt.ylabel('Memory usage (MiB)')
lab = 'LARGE' if LARGE_SLICE else 'SMALL'
plt.title(f'{lab} slice w/steps, Linux Blosc2 {blosc2.__version__}')
# Legend entries follow the order in which the functions were plotted above.
plt.legend([f'expr.slice({_slice})', f'expr[{_slice}]'])
# Build the version tag outside the f-string: reusing the enclosing quote
# character inside a replacement field is a SyntaxError before Python 3.12.
version_tag = blosc2.__version__.replace('.', '_')
plt.savefig(f'sliceexpr_{lab}_Blosc{version_tag}.png', format="png")

src/blosc2/lazyexpr.py

Lines changed: 62 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,26 +1412,19 @@ def slices_eval( # noqa: C901
14121412
# keep orig_slice
14131413
_slice = _slice.raw
14141414
orig_slice = _slice
1415-
full_slice = () # by default the full_slice is the whole array
1416-
final_slice = () # by default the final_slice is the whole array
14171415

14181416
# Compute the shape and chunks of the output array, including broadcasting
14191417
shape = compute_broadcast_shape(operands.values())
14201418
if out is None:
14211419
if _slice != ():
14221420
# Check whether _slice contains an integer (non-unit steps no longer need a final slice)
1423-
if any(
1424-
(isinstance(s, int)) or (isinstance(s, slice) and s.step not in (None, 1)) for s in _slice
1425-
):
1421+
if any((isinstance(s, int)) for s in _slice):
14261422
need_final_slice = True
1427-
_slice = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in _slice)
1428-
full_slice = tuple(
1429-
slice(s.start or 0, s.stop or shape[i], 1) for i, s in enumerate(_slice)
1430-
) # get rid of non-unit steps
1423+
_slice = tuple(slice(i, i + 1, 1) if isinstance(i, int) else i for i in _slice)
14311424
# shape_slice in general not equal to final shape:
1432-
# dummy dims (due to ints) or non-unit steps will be dealt with by taking final_slice
1433-
shape_slice = ndindex.ndindex(full_slice).newshape(shape)
1434-
final_slice = ndindex.ndindex(orig_slice).as_subindex(full_slice).raw
1425+
# dummy dims (due to ints) will be dealt with by taking final_slice
1426+
shape_slice = ndindex.ndindex(_slice).newshape(shape)
1427+
mask_slice = np.bool([isinstance(i, int) for i in orig_slice])
14351428
else:
14361429
# # out should always have shape of full array
14371430
# if shape is not None and shape != out.shape:
@@ -1476,23 +1469,21 @@ def slices_eval( # noqa: C901
14761469
for nchunk, chunk_slice in enumerate(intersecting_chunks):
14771470
# get intersection of chunk and target
14781471
if _slice != ():
1479-
cslice = tuple(
1480-
slice(max(s1.start, s2.start), min(s1.stop, s2.stop))
1481-
for s1, s2 in zip(chunk_slice.raw, _slice, strict=True)
1482-
)
1472+
cslice = step_handler(chunk_slice.raw, _slice)
14831473
else:
14841474
cslice = chunk_slice.raw
14851475

14861476
cslice_shape = tuple(s.stop - s.start for s in cslice)
14871477
len_chunk = math.prod(cslice_shape)
14881478
# get local index of part of out that is to be updated
14891479
cslice_subidx = (
1490-
ndindex.ndindex(cslice).as_subindex(full_slice).raw
1491-
) # in the case full_slice=(), just gives cslice
1480+
ndindex.ndindex(cslice).as_subindex(_slice).raw
1481+
) # in the case _slice=(), just gives cslice
14921482

14931483
# Get the starts and stops for the slice
14941484
starts = [s.start if s.start is not None else 0 for s in cslice]
14951485
stops = [s.stop if s.stop is not None else sh for s, sh in zip(cslice, cslice_shape, strict=True)]
1486+
unit_steps = np.all([s.step == 1 for s in cslice])
14961487

14971488
# Get the slice of each operand
14981489
for key, value in operands.items():
@@ -1512,6 +1503,7 @@ def slices_eval( # noqa: C901
15121503
key in chunk_operands
15131504
and cslice_shape == chunk_operands[key].shape
15141505
and isinstance(value, blosc2.NDArray)
1506+
and unit_steps
15151507
):
15161508
value.get_slice_numpy(chunk_operands[key], (starts, stops))
15171509
continue
@@ -1565,6 +1557,9 @@ def slices_eval( # noqa: C901
15651557
result = x[result]
15661558
else:
15671559
raise ValueError("The where condition must be a tuple with one or two elements")
1560+
# Enforce contiguity of result (necessary to fill the out array)
1561+
# but avoid copy if already contiguous
1562+
result = np.require(result, requirements="C")
15681563

15691564
if out is None:
15701565
shape_ = shape_slice if shape_slice is not None else shape
@@ -1622,15 +1617,13 @@ def slices_eval( # noqa: C901
16221617
out.resize((lenout,))
16231618

16241619
else: # Need to take final_slice since filled up array according to slice_ for each chunk
1625-
if final_slice != ():
1620+
if need_final_slice: # only called if out was None
16261621
if isinstance(out, np.ndarray):
1627-
if need_final_slice: # only called if out was None
1628-
out = out[final_slice]
1622+
out = np.squeeze(out, np.where(mask_slice)[0])
16291623
elif isinstance(out, blosc2.NDArray):
16301624
# It *seems* better to choose an automatic chunks and blocks for the output array
16311625
# out = out.slice(_slice, chunks=out.chunks, blocks=out.blocks)
1632-
if need_final_slice: # only called if out was None
1633-
out = out.slice(final_slice)
1626+
out = out.squeeze(mask_slice)
16341627
else:
16351628
raise ValueError("The output array is not a NumPy array or a NDArray")
16361629

@@ -1743,17 +1736,22 @@ def infer_reduction_dtype(dtype, operation):
17431736
raise ValueError(f"Unsupported operation: {operation}")
17441737

17451738

1746-
def step_handler(s1start, s2start, s1stop, s2stop, s2step):
    """Intersect a unit-step range with a stepped range along one axis.

    ``[s1start, s1stop)`` is taken to have step 1; ``[s2start, s2stop)`` is
    walked with step ``s2step``. The returned slice starts at the first index
    of the second range that lies inside the first one, keeps ``s2step`` as
    its step, and stops one past the last such index.
    """
    lo = max(s1start, s2start)
    hi = min(s1stop, s2stop)

    # Move lo forward to the next index on the s2 grid; a no-op when it is
    # already aligned (which is always the case for s2step == 1).
    misalignment = (lo - s2start) % s2step
    if misalignment != 0:
        lo += s2step - misalignment

    # The last selected index is lo + n * s2step with n = (hi - lo - 1) // s2step,
    # so the tight exclusive stop is that index plus one.
    n_steps = (hi - lo - 1) // s2step
    return slice(lo, lo + n_steps * s2step + 1, s2step)
1739+
def step_handler(cslice, _slice):
1740+
out = ()
1741+
for s1, s2 in zip(cslice, _slice, strict=True):
1742+
s1start, s1stop = s1.start, s1.stop
1743+
s2start, s2stop, s2step = s2.start, s2.stop, s2.step
1744+
# assume s1step = 1
1745+
newstart = max(s1start, s2start)
1746+
newstop = min(s1stop, s2stop)
1747+
rem = (newstart - s2start) % s2step
1748+
if rem != 0: # only pass through here if s2step is not 1
1749+
newstart += s2step - rem
1750+
# true_stop = start + n*step + 1 -> stop = start + n * step + 1 + residual
1751+
# so n = (stop - start - 1) // step
1752+
newstop = newstart + (newstop - newstart - 1) // s2step * s2step + 1
1753+
out += (slice(newstart, newstop, s2step),)
1754+
return out
17571755

17581756

17591757
def reduce_slices( # noqa: C901
@@ -1807,21 +1805,27 @@ def reduce_slices( # noqa: C901
18071805

18081806
_slice = _slice.raw
18091807
shape_slice = shape
1810-
full_slice = () # by default the full_slice is the whole array
1808+
mask_slice = np.bool([isinstance(i, int) for i in _slice])
18111809
if out is None and _slice != ():
1810+
_slice = tuple(slice(i, i + 1, 1) if isinstance(i, int) else i for i in _slice)
18121811
shape_slice = ndindex.ndindex(_slice).newshape(shape)
1813-
full_slice = _slice
1812+
# shape_slice in general not equal to final shape:
1813+
# dummy dims (due to ints) will be dealt with by taking final_slice
18141814

18151815
# after slicing, we reduce to calculate shape of output
18161816
if axis is None:
18171817
axis = tuple(range(len(shape_slice)))
18181818
elif not isinstance(axis, tuple):
18191819
axis = (axis,)
1820-
axis = tuple(a if a >= 0 else a + len(shape_slice) for a in axis)
1820+
axis = np.array([a if a >= 0 else a + len(shape_slice) for a in axis])
1821+
if np.any(mask_slice):
1822+
axis = tuple(axis + np.cumsum(mask_slice)[axis]) # axis now refers to new shape with dummy dims
1823+
reduce_args["axis"] = axis
18211824
if keepdims:
18221825
reduced_shape = tuple(1 if i in axis else s for i, s in enumerate(shape_slice))
18231826
else:
18241827
reduced_shape = tuple(s for i, s in enumerate(shape_slice) if i not in axis)
1828+
mask_slice = mask_slice[[i for i in range(len(mask_slice)) if i not in axis]]
18251829

18261830
if out is not None and reduced_shape != out.shape:
18271831
raise ValueError("Provided output shape does not match the reduced shape.")
@@ -1876,13 +1880,10 @@ def reduce_slices( # noqa: C901
18761880
# Check whether current cslice intersects with _slice
18771881
if cslice != () and _slice != ():
18781882
# get intersection of chunk and target
1879-
cslice = tuple(
1880-
step_handler(s1.start, s2.start, s1.stop, s2.stop, s2.step)
1881-
for s1, s2 in zip(cslice, _slice, strict=True)
1882-
)
1883+
cslice = step_handler(cslice, _slice)
18831884
chunks_ = tuple(s.stop - s.start for s in cslice)
1884-
1885-
if _slice == () and fast_path:
1885+
unit_steps = np.all([s.step == 1 for s in cslice])
1886+
if _slice == () and fast_path and unit_steps:
18861887
# Fast path
18871888
full_chunk = chunks_ == chunks
18881889
fill_chunk_operands(
@@ -1910,15 +1911,14 @@ def reduce_slices( # noqa: C901
19101911
key in chunk_operands
19111912
and chunks_ == chunk_operands[key].shape
19121913
and isinstance(value, blosc2.NDArray)
1914+
and unit_steps
19131915
):
19141916
value.get_slice_numpy(chunk_operands[key], (starts, stops))
19151917
continue
19161918
chunk_operands[key] = value[cslice]
19171919

19181920
# get local index of part of out that is to be updated
1919-
cslice_subidx = (
1920-
ndindex.ndindex(cslice).as_subindex(full_slice).raw
1921-
) # if full_slice is (), just gives cslice
1921+
cslice_subidx = ndindex.ndindex(cslice).as_subindex(_slice).raw # if _slice is (), just gives cslice
19221922
if keepdims:
19231923
reduced_slice = tuple(slice(None) if i in axis else sl for i, sl in enumerate(cslice_subidx))
19241924
else:
@@ -1938,8 +1938,8 @@ def reduce_slices( # noqa: C901
19381938

19391939
if where is None:
19401940
if expression == "o0":
1941-
# We don't have an actual expression, so avoid a copy
1942-
result = chunk_operands["o0"]
1941+
# We don't have an actual expression, so avoid a copy except to make contiguous
1942+
result = np.require(chunk_operands["o0"], requirements="C")
19431943
else:
19441944
result = ne_evaluate(expression, chunk_operands, **ne_args)
19451945
else:
@@ -1997,6 +1997,9 @@ def reduce_slices( # noqa: C901
19971997
dtype = np.float64
19981998
out = convert_none_out(dtype, reduce_op, reduced_shape)
19991999

2000+
final_mask = tuple(np.where(mask_slice)[0])
2001+
if np.any(mask_slice): # remove dummy dims
2002+
out = np.squeeze(out, axis=final_mask)
20002003
# Check if the output array needs to be converted into a blosc2.NDArray
20012004
if kwargs != {} and not np.isscalar(out):
20022005
out = blosc2.asarray(out, **kwargs)
@@ -2089,7 +2092,9 @@ def chunked_eval( # noqa: C901
20892092
# The fast path is possible under a few conditions
20902093
if getitem and (where is None or len(where) == 2) and not callable(expression):
20912094
# Compute the size of operands for the fast path
2092-
shape_operands = item.newshape(shape) # shape of slice
2095+
unit_steps = np.all([s.step == 1 for s in item.raw if isinstance(s, slice)])
2096+
# shape of slice; with non-unit steps the full array has to be decompressed into memory
2097+
shape_operands = item.newshape(shape) if unit_steps else shape
20932098
_dtype = kwargs.get("dtype", np.float64)
20942099
size_operands = math.prod(shape_operands) * len(operands) * _dtype.itemsize
20952100
# Only take the fast path if the size of operands is relatively small
@@ -3118,6 +3123,14 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
31183123
new_expr.expression_tosave = expression
31193124
new_expr.operands = operands_
31203125
new_expr.operands_tosave = operands
3126+
elif isinstance(new_expr, blosc2.NDArray) and len(operands) == 1:
3127+
# passed either "a" or possibly "a[:10]"
3128+
expression_, operands_ = conserve_functions(
3129+
_expression, _operands, {"o0": list(operands.values())[0]} | local_vars
3130+
)
3131+
new_expr = cls(None)
3132+
new_expr.expression = expression_
3133+
new_expr.operands = operands_
31213134
else:
31223135
# An immediate evaluation happened (e.g. all operands are numpy arrays)
31233136
new_expr = cls(None)

tests/ndarray/test_concatenate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,5 @@ def test_stack(shape, dtype, axis):
9999
[ndarr1, ndarr2, ndarr3], axis=axis, cparams=cparams, urlpath="localfile.b2nd", mode="w"
100100
)
101101
np.testing.assert_almost_equal(result[:], nparray)
102+
# Remove localfile
103+
blosc2.remove_urlpath("localfile.b2nd")

tests/ndarray/test_lazyexpr.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,7 @@ def test_eval_getitem2():
10881088
np.testing.assert_allclose(expr[0], nres[0])
10891089
np.testing.assert_allclose(expr[1:, :7], nres[1:, :7])
10901090
np.testing.assert_allclose(expr[0:10:2], nres[0:10:2])
1091-
# This works, but it is not very efficient since it relies on blosc2.ndarray.slice for non-unit steps
1091+
# Now relies on inefficient blosc2.ndarray.slice for non-unit steps but only per chunk (not for whole result)
10921092
np.testing.assert_allclose(expr.slice((slice(None, None, None), slice(0, 10, 2)))[:], nres[:, 0:10:2])
10931093

10941094
# Small test for broadcasting
@@ -1097,21 +1097,21 @@ def test_eval_getitem2():
10971097
np.testing.assert_allclose(expr[0], nres[0])
10981098
np.testing.assert_allclose(expr[1:, :7], nres[1:, :7])
10991099
np.testing.assert_allclose(expr[:, 0:10:2], nres[:, 0:10:2])
1100-
# This works, but it is not very efficient since it relies on blosc2.ndarray.slice for non-unit steps
1100+
# Now relies on inefficient blosc2.ndarray.slice for non-unit steps but only per chunk (not for whole result)
11011101
np.testing.assert_allclose(expr.slice((slice(None, None, None), slice(0, 10, 2)))[:], nres[:, 0:10:2])
11021102

11031103

11041104
# Test lazyexpr's slice method
11051105
def test_eval_slice(array_fixture):
11061106
a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture
11071107
expr = blosc2.lazyexpr("a1 + a2 - (a3 * a4)", operands={"a1": a1, "a2": a2, "a3": a3, "a4": a4})
1108-
nres = ne_evaluate("na1 + na2 - (na3 * na4)")[:2]
1109-
res = expr.slice(slice(0, 2))
1108+
nres = ne_evaluate("na1 + na2 - (na3 * na4)")
1109+
res = expr.slice(slice(0, 8, 2))
11101110
assert isinstance(res, blosc2.ndarray.NDArray)
1111-
np.testing.assert_allclose(res[:], nres)
1112-
res = expr[:2]
1111+
np.testing.assert_allclose(res[:], nres[:8:2])
1112+
res = expr[:8:2]
11131113
assert isinstance(res, np.ndarray)
1114-
np.testing.assert_allclose(res, nres)
1114+
np.testing.assert_allclose(res, nres[:8:2])
11151115

11161116
# string lazy expressions automatically use .slice internally
11171117
expr1 = blosc2.lazyexpr("a1 * a2", operands={"a1": a1, "a2": a2})
@@ -1123,6 +1123,18 @@ def test_eval_slice(array_fixture):
11231123
np.testing.assert_allclose(res[()], nres)
11241124

11251125

1126+
def test_rebasing(array_fixture):
    """Check the ``oN``-based expression string stored by ``blosc2.lazyexpr``."""
    a1, a2, a3, a4, na1, na2, na3, na4 = array_fixture

    # Several operands, passed explicitly
    multi = blosc2.lazyexpr(
        "a1 + a2 - (a3 * a4)",
        operands={"a1": a1, "a2": a2, "a3": a3, "a4": a4},
    )
    assert multi.expression == "(o0 + o1 - o2 * o3)"

    # A bare operand name, with no operands dict supplied
    single = blosc2.lazyexpr("a1")
    assert single.expression == "o0"

    # A sliced operand, with no operands dict supplied
    sliced = blosc2.lazyexpr("a1[:10]")
    assert sliced.expression == "o0.slice((slice(None, 10, None),))"
1136+
1137+
11261138
# Test get_chunk method
11271139
@pytest.mark.heavy
11281140
def test_get_chunk(array_fixture):
@@ -1457,6 +1469,10 @@ def test_chain_persistentexpressions():
14571469
myle4 = blosc2.open("expr4.b2nd")
14581470
assert (myle4[:] == le4[:]).all()
14591471

1472+
# Remove files
1473+
for f in ["expr1.b2nd", "expr2.b2nd", "expr3.b2nd", "expr4.b2nd", "a.b2nd", "b.b2nd", "c.b2nd"]:
1474+
blosc2.remove_urlpath(f)
1475+
14601476

14611477
@pytest.mark.parametrize(
14621478
"values",

tests/ndarray/test_lazyexpr_fields.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,3 +671,6 @@ def test_fields_indexing():
671671
gotitem = expr[0] # gives an error
672672
np.testing.assert_array_equal(sliced[()], gotitem)
673673
np.testing.assert_array_equal(gotitem, temp[0])
674+
675+
# Remove file
676+
blosc2.remove_urlpath("sa-1M.b2nd")

0 commit comments

Comments
 (0)