Skip to content

Commit dfe6c92

Browse files
Merge pull request #430 from Blosc/slices_eval_getitem
New slices_eval_getitem() for accelerating expression evaluation with __getitem__
2 parents 6cccc23 + c94b85c commit dfe6c92

File tree

4 files changed

+298
-18
lines changed

4 files changed

+298
-18
lines changed

bench/ndarray/slice-expr.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Imports
2+
import numpy as np
3+
import blosc2
4+
import time
5+
from memory_profiler import memory_usage
6+
import matplotlib.pyplot as plt
7+
8+
file = "dset-ones.b2nd"
9+
# a = blosc2.open(file)
10+
# expr = blosc2.where(a < 5, a * 2**14, a)
11+
d = 160
12+
shape = (d,) * 4
13+
chunks = (d // 4,) * 4
14+
blocks = (d // 10,) * 4
15+
print(f"Creating a 4D array of shape {shape} with chunks {chunks} and blocks {blocks}...")
16+
t = time.time()
17+
#a = blosc2.linspace(0, d, num=d**4, shape=(d,) * 4, blocks=(d//10,) * 4, chunks=(d//2,) * 4, urlpath=file, mode="w")
18+
#a = blosc2.linspace(0, d, num = d**4, shape=(d,)*4, blocks=(d//10,)*4, chunks=(d//2,)*4)
19+
# a = blosc2.arange(0, d**4, shape=(d,) * 4, blocks=(d//10,) * 4, chunks=(d//2,) * 4, urlpath=file, mode="w")
20+
a = blosc2.ones(shape=shape, chunks=chunks, blocks=blocks) #, urlpath=file, mode="w")
21+
t = time.time() - t
22+
print(f"Time to create array: {t:.6f} seconds")
23+
t = time.time()
24+
#expr = a * 30
25+
expr = a * 2
26+
print(f"Time to create expression: {time.time() - t:.6f} seconds")
27+
28+
# dim0
29+
def slice_dim0():
30+
t = time.time()
31+
res = expr[1]
32+
t0 = time.time() - t
33+
print(f"Time to access dim0: {t0:.6f} seconds")
34+
35+
# dim1
36+
def slice_dim1():
37+
t = time.time()
38+
res = expr[:,1]
39+
t1 = time.time() - t
40+
print(f"Time to access dim1: {t1:.6f} seconds")
41+
42+
# dim2
43+
def slice_dim2():
44+
t = time.time()
45+
res = expr[:,:,1]
46+
t2 = time.time() - t
47+
print(f"Time to access dim2: {t2:.6f} seconds")
48+
49+
# dim3
50+
def slice_dim3():
51+
t = time.time()
52+
res = expr[:,:,:,1]
53+
#res = expr[1]
54+
t3 = time.time() - t
55+
56+
print(f"Time to access dim3: {t3:.6f} seconds")
57+
58+
fig = plt.figure()
59+
interval = 0.001
60+
offset = 0
61+
for f in [slice_dim0, slice_dim1, slice_dim2, slice_dim3]:
62+
mem = memory_usage((f,), interval=interval)
63+
times = offset + interval * np.arange(len(mem))
64+
offset = times[-1]
65+
plt.plot(times, mem)
66+
67+
plt.xlabel('Time (s)')
68+
plt.ylabel('Memory usage (MiB)')
69+
plt.title('Memory usage over time for slicing operations, slice-expr.py')
70+
plt.legend(['dim0', 'dim1', 'dim2', 'dim3'])
71+
plt.savefig('plots/slice-expr.png', format="png")

src/blosc2/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ class Tuner(Enum):
138138
"""
139139
Maximum buffer size in bytes for a Blosc2 chunk."""
140140

141+
MAX_FAST_PATH_SIZE = 2**30
142+
"""
143+
Maximum size in bytes for a fast path evaluation.
144+
"""
145+
141146
MAX_OVERHEAD = MAX_OVERHEAD
142147
"""
143148
Maximum overhead during compression (in bytes). This is

src/blosc2/lazyexpr.py

Lines changed: 193 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,66 @@
4646
import numexpr
4747

4848

49+
def compute_slice_shape(shape, slice_obj, dont_squeeze=False): # noqa: C901
50+
# Handle None or empty slice case
51+
if slice_obj is None or slice_obj == ():
52+
return shape
53+
54+
# Use ndindex to handle slice calculations
55+
try:
56+
idx = ndindex.ndindex(slice_obj).expand(shape)
57+
return idx.shape
58+
except Exception:
59+
# Fall back to manual processing
60+
if not isinstance(slice_obj, tuple):
61+
slice_obj = (slice_obj,)
62+
63+
result = []
64+
shape_idx = 0
65+
dims_reduced = 0
66+
67+
# Process slice components
68+
for i, s in enumerate(slice_obj):
69+
if i >= len(shape):
70+
break
71+
72+
if isinstance(s, slice):
73+
start = 0 if s.start is None else max(0, s.start if s.start >= 0 else shape[i] + s.start)
74+
stop = (
75+
shape[i]
76+
if s.stop is None
77+
else min(shape[i], s.stop if s.stop >= 0 else shape[i] + s.stop)
78+
)
79+
step = 1 if s.step is None else abs(s.step)
80+
81+
if start < stop:
82+
result.append((stop - start - 1) // step + 1)
83+
else:
84+
result.append(0)
85+
shape_idx += 1
86+
elif isinstance(s, int) or np.isscalar(s):
87+
if dont_squeeze:
88+
result.append(1)
89+
shape_idx += 1
90+
else:
91+
# Integer indexing reduces dimensionality
92+
dims_reduced += 1
93+
shape_idx += 1
94+
continue
95+
elif s is Ellipsis:
96+
# Fill in with remaining dimensions
97+
remaining_dims = len(shape) - (len(slice_obj) - 1 + dims_reduced)
98+
result.extend(shape[shape_idx : shape_idx + remaining_dims])
99+
shape_idx += remaining_dims
100+
continue
101+
102+
# Add any remaining dimensions
103+
if shape_idx < len(shape):
104+
result.extend(shape[shape_idx:])
105+
106+
return tuple(result)
107+
108+
49109
def ne_evaluate(expression, local_dict=None, **kwargs):
50110
"""Safely evaluate expressions using numexpr when possible, falling back to numpy."""
51111
if local_dict is None:
@@ -505,15 +565,18 @@ def compute_broadcast_shape(arrays):
505565
return np.broadcast_shapes(*shapes) if shapes else None
506566

507567

508-
def check_smaller_shape(value, shape, slice_shape):
568+
def check_smaller_shape(value_shape, shape, slice_shape):
509569
"""Check whether the shape of the value is smaller than the shape of the array.
510570
511571
This follows the NumPy broadcasting rules.
512572
"""
573+
# slice_shape must be as long as shape
574+
if len(slice_shape) != len(shape):
575+
raise ValueError("slice_shape must be as long as shape")
513576
is_smaller_shape = any(
514-
s > (1 if i >= len(value.shape) else value.shape[i]) for i, s in enumerate(slice_shape)
577+
s > (1 if i >= len(value_shape) else value_shape[i]) for i, s in enumerate(slice_shape)
515578
)
516-
return len(value.shape) < len(shape) or is_smaller_shape
579+
return len(value_shape) < len(shape) or is_smaller_shape
517580

518581

519582
def _compute_smaller_slice(larger_shape, smaller_shape, larger_slice):
@@ -1304,8 +1367,9 @@ def slices_eval( # noqa: C901
13041367
operands: dict
13051368
A dictionary containing the operands for the expression.
13061369
getitem: bool, optional
1307-
Indicates whether the expression is being evaluated for a getitem operation.
1308-
_slice: slice, list of slices, optional
1370+
Indicates whether the expression is being evaluated for a getitem operation or compute().
1371+
Default is False.
1372+
_slice: int, slice, list of slices, optional
13091373
If provided, only the chunks that intersect with this slice
13101374
will be evaluated.
13111375
kwargs: Any, optional
@@ -1331,9 +1395,23 @@ def slices_eval( # noqa: C901
13311395
_order = [_order]
13321396

13331397
dtype = kwargs.pop("dtype", None)
1398+
shape_slice = None
1399+
_slice_step = False
13341400
if out is None:
13351401
# Compute the shape and chunks of the output array, including broadcasting
13361402
shape = compute_broadcast_shape(operands.values())
1403+
if _slice is not None:
1404+
# Remove the step parts from the slice, as code below does not support it
1405+
# First ensure _slice is a tuple, even if it's a single slice
1406+
_slice_ = _slice if isinstance(_slice, tuple) else (_slice,)
1407+
# Check whether _slice_ contains any step that are not None or 1
1408+
if any(isinstance(s, slice) and s.step not in (None, 1) for s in _slice_):
1409+
_slice_step = True
1410+
_slice_ = tuple(
1411+
slice(s.start or 0, s.stop or shape[i], None) if isinstance(s, slice) else s
1412+
for i, s in enumerate(_slice_)
1413+
)
1414+
shape_slice = compute_slice_shape(shape, _slice_, dont_squeeze=True)
13371415
else:
13381416
shape = out.shape
13391417

@@ -1412,7 +1490,7 @@ def slices_eval( # noqa: C901
14121490
if value.shape == ():
14131491
chunk_operands[key] = value[()]
14141492
continue
1415-
if check_smaller_shape(value, shape, slice_shape):
1493+
if check_smaller_shape(value.shape, shape, slice_shape):
14161494
# We need to fetch the part of the value that broadcasts with the operand
14171495
smaller_slice = compute_smaller_slice(shape, value.shape, slice_)
14181496
chunk_operands[key] = value[smaller_slice]
@@ -1476,18 +1554,19 @@ def slices_eval( # noqa: C901
14761554
raise ValueError("The where condition must be a tuple with one or two elements")
14771555

14781556
if out is None:
1479-
shape_ = shape
1557+
shape_ = shape_slice if shape_slice is not None else shape
14801558
if where is not None and len(where) < 2:
14811559
# The result is a linear array
1482-
shape_ = math.prod(shape)
1560+
shape_ = math.prod(shape_)
14831561
if getitem or _order:
14841562
out = np.empty(shape_, dtype=dtype_)
14851563
if _order:
14861564
indices_ = np.empty(shape_, dtype=np.int64)
14871565
else:
14881566
if "chunks" not in kwargs and (where is None or len(where) == 2):
14891567
# Let's use the same chunks as the first operand (it could have been automatic too)
1490-
out = blosc2.empty(shape_, chunks=chunks, dtype=dtype_, **kwargs)
1568+
# out = blosc2.empty(shape_, chunks=chunks, dtype=dtype_, **kwargs)
1569+
out = blosc2.empty(shape_, dtype=dtype_, **kwargs)
14911570
elif "chunks" in kwargs and (where is not None and len(where) < 2 and len(shape_) > 1):
14921571
# Remove the chunks argument if the where condition is not a tuple with two elements
14931572
kwargs.pop("chunks")
@@ -1527,19 +1606,107 @@ def slices_eval( # noqa: C901
15271606
else: # Need to take orig_slice since filled up array according to slice_ for each chunk
15281607
if orig_slice is not None:
15291608
if isinstance(out, np.ndarray):
1530-
out = out[orig_slice]
1531-
if _order is not None:
1532-
indices_ = indices_[orig_slice]
1609+
out = np.squeeze(out)
1610+
if _slice_step:
1611+
out = out[orig_slice]
15331612
elif isinstance(out, blosc2.NDArray):
15341613
# It *seems* better to choose an automatic chunks and blocks for the output array
15351614
# out = out.slice(orig_slice, chunks=out.chunks, blocks=out.blocks)
1536-
out = out.slice(orig_slice)
1615+
out = out.squeeze()
1616+
if _slice_step:
1617+
out = out.slice(orig_slice)
15371618
else:
15381619
raise ValueError("The output array is not a NumPy array or a NDArray")
15391620

15401621
return out
15411622

15421623

1624+
def slices_eval_getitem(
1625+
expression: str,
1626+
operands: dict,
1627+
_slice=None,
1628+
**kwargs,
1629+
) -> np.ndarray:
1630+
"""Evaluate the expression in slices of operands.
1631+
1632+
This function can handle operands with different chunk shapes and
1633+
can evaluate only a slice of the output array if needed.
1634+
1635+
This is a special (and much simplified) version of slices_eval() that
1636+
only works for the case we are returning a NumPy array, where is
1637+
either None or has two args, and expression is not callable.
1638+
1639+
One inconvenient of this function is that it tries to evaluate
1640+
the whole slice in one go. For small slices, this is good, as it
1641+
is normally way more efficient. However, for larger slices this
1642+
can require large amounts of memory per operand.
1643+
1644+
Parameters
1645+
----------
1646+
expression: str or callable
1647+
The expression or user-defined (udf) to evaluate.
1648+
operands: dict
1649+
A dictionary containing the operands for the expression.
1650+
_slice: int, slice, list of slices, optional
1651+
If provided, this slice will be evaluated.
1652+
kwargs: Any, optional
1653+
Additional keyword arguments that are supported by the :func:`empty` constructor.
1654+
1655+
Returns
1656+
-------
1657+
:ref:`NDArray` or np.ndarray
1658+
The output array.
1659+
"""
1660+
out: np.ndarray | None = kwargs.pop("_output", None)
1661+
ne_args: dict = kwargs.pop("_ne_args", {})
1662+
if ne_args is None:
1663+
ne_args = {}
1664+
where: dict | None = kwargs.pop("_where_args", None)
1665+
1666+
dtype = kwargs.pop("dtype", None)
1667+
if out is None:
1668+
# Compute the shape and chunks of the output array, including broadcasting
1669+
shape = compute_broadcast_shape(operands.values())
1670+
else:
1671+
shape = out.shape
1672+
1673+
# compute the shape of the output array, broadcasting-compatible
1674+
_slice = ndindex.ndindex(_slice).expand(shape).raw # make sure slice is tuple
1675+
_slice_bcast = tuple(slice(i, i + 1) if isinstance(i, int) else i for i in _slice)
1676+
slice_shape = compute_slice_shape(shape, _slice_bcast)
1677+
1678+
# Get the slice of each operand
1679+
slice_operands = {}
1680+
for key, value in operands.items():
1681+
if np.isscalar(value):
1682+
slice_operands[key] = value
1683+
continue
1684+
if value.shape == ():
1685+
slice_operands[key] = value[()]
1686+
continue
1687+
if check_smaller_shape(value.shape, shape, slice_shape):
1688+
# We need to fetch the part of the value that broadcasts with the operand
1689+
smaller_slice = compute_smaller_slice(shape, value.shape, _slice)
1690+
slice_operands[key] = value[smaller_slice]
1691+
continue
1692+
1693+
slice_operands[key] = value[_slice]
1694+
1695+
# Evaluate the expression using slices of operands
1696+
if where is None:
1697+
result = ne_evaluate(expression, slice_operands, **ne_args)
1698+
else:
1699+
# Apply the where condition (in result)
1700+
new_expr = f"where({expression}, _where_x, _where_y)"
1701+
result = ne_evaluate(new_expr, slice_operands, **ne_args)
1702+
1703+
if out is None: # avoid copying unnecessarily
1704+
return result.astype(dtype, copy=False) if dtype else result
1705+
else:
1706+
out[()] = result
1707+
return out
1708+
1709+
15431710
def infer_reduction_dtype(dtype, operation):
15441711
# It may change in the future, but for now, this mimics NumPy's (2.1) behavior pretty well
15451712
if operation in {ReduceOp.SUM, ReduceOp.PROD}:
@@ -1574,7 +1741,7 @@ def reduce_slices( # noqa: C901
15741741
A dictionary containing the operands for the operands.
15751742
reduce_args: dict
15761743
A dictionary with arguments to be passed to the reduction function.
1577-
_slice: slice, list of slices, optional
1744+
_slice: int, slice, list of slices, optional
15781745
If provided, only the chunks that intersect with this slice
15791746
will be evaluated.
15801747
kwargs: Any, optional
@@ -1666,6 +1833,7 @@ def reduce_slices( # noqa: C901
16661833
reduced_slice = tuple(sl for i, sl in enumerate(slice_) if i not in axis)
16671834
offset = tuple(s.start for s in slice_) # offset for the udf
16681835
# Check whether current slice_ intersects with _slice
1836+
# TODO: Is this necessary, shouldn't slice always be None for a reduction?
16691837
if _slice is not None and _slice != ():
16701838
# Ensure that slices do not have any None as start or stop
16711839
_slice = tuple(slice(s.start or 0, s.stop or shape[i], s.step) for i, s in enumerate(_slice))
@@ -1698,7 +1866,7 @@ def reduce_slices( # noqa: C901
16981866
if value.shape == ():
16991867
chunk_operands[key] = value[()]
17001868
continue
1701-
if check_smaller_shape(value, shape, chunks_):
1869+
if check_smaller_shape(value.shape, shape, chunks_):
17021870
# We need to fetch the part of the value that broadcasts with the operand
17031871
smaller_slice = compute_smaller_slice(operand.shape, value.shape, slice_)
17041872
chunk_operands[key] = value[smaller_slice]
@@ -1869,8 +2037,16 @@ def chunked_eval( # noqa: C901
18692037
return reduce_slices(expression, operands, reduce_args=reduce_args, _slice=item, **kwargs)
18702038

18712039
if not is_full_slice(item) or (where is not None and len(where) < 2):
1872-
# The fast path is not possible when using partial slices or where returning
1873-
# a variable number of elements
2040+
# The fast path is possible under a few conditions
2041+
if getitem and (where is None or len(where) == 2) and not callable(expression):
2042+
# Compute the size of operands for the fast path
2043+
shape = compute_broadcast_shape(operands.values())
2044+
shape_operands = compute_slice_shape(shape, item)
2045+
_dtype = kwargs.get("dtype", np.float64)
2046+
size_operands = math.prod(shape_operands) * len(operands) * _dtype.itemsize
2047+
# Only take the fast path if the size of operands is relatively small
2048+
if size_operands < blosc2.MAX_FAST_PATH_SIZE:
2049+
return slices_eval_getitem(expression, operands, _slice=item, **kwargs)
18742050
return slices_eval(expression, operands, getitem=getitem, _slice=item, **kwargs)
18752051

18762052
if fast_path:

0 commit comments

Comments
 (0)