Skip to content

Commit 22b023b

Browse files
committed
Fix for new slices_eval2()
1 parent ba2753e commit 22b023b

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

src/blosc2/lazyexpr.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1669,12 +1669,23 @@ def slices_eval2( # noqa: C901
16691669

16701670
dtype = kwargs.pop("dtype", None)
16711671
shape_slice = None
1672+
_slice_step = False
16721673
if out is None:
16731674
# Compute the shape and chunks of the output array, including broadcasting
16741675
shape = compute_broadcast_shape(operands.values())
16751676
if _slice is not None:
16761677
# print("shape abans:", shape)
1677-
shape_slice = compute_slice_shape(shape, _slice, dont_squeeze=True)
1678+
# Remove the step parts from the slice, as code below does not support it
1679+
# First ensure _slice is a tuple, even if it's a single slice
1680+
_slice_ = _slice if isinstance(_slice, tuple) else (_slice,)
1681+
# Check whether _slice_ contains any step that are not None or 1
1682+
if any(isinstance(s, slice) and s.step not in (None, 1) for s in _slice_):
1683+
_slice_step = True
1684+
_slice_ = tuple(
1685+
slice(s.start or 0, s.stop or shape[i], None) if isinstance(s, slice) else s
1686+
for i, s in enumerate(_slice_)
1687+
)
1688+
shape_slice = compute_slice_shape(shape, _slice_, dont_squeeze=True)
16781689
# print("shape despres:", shape_slice)
16791690
else:
16801691
shape = out.shape
@@ -1726,15 +1737,12 @@ def slices_eval2( # noqa: C901
17261737
checker = _slice.item() if hasattr(_slice, "item") else _slice # can't use != when _slice is np.int
17271738
if checker is not None and checker != ():
17281739
# Ensure that _slice is of type slice
1729-
# print("_slice, shape:", _slice, shape)
17301740
key = ndindex.ndindex(_slice).expand(shape).raw
17311741
_slice = tuple(k if isinstance(k, slice) else slice(k, k + 1, None) for k in key)
17321742
# Ensure that slices do not have any None as start or stop
17331743
_slice = tuple(slice(s.start or 0, s.stop or shape[i], s.step) for i, s in enumerate(_slice))
17341744
slice_ = tuple(slice(s.start or 0, s.stop or shape[i], s.step) for i, s in enumerate(slice_))
17351745
intersects = do_slices_intersect(_slice, slice_)
1736-
# print("_slice:", _slice)
1737-
# print("slice_:", slice_)
17381746
if not intersects:
17391747
continue
17401748
# Compute the part of the slice_ that intersects with _slice
@@ -1748,9 +1756,6 @@ def slices_eval2( # noqa: C901
17481756
# Get the starts and stops for the slice
17491757
starts = [s.start if s.start is not None else 0 for s in slice_]
17501758
stops = [s.stop if s.stop is not None else sh for s, sh in zip(slice_, slice_shape, strict=True)]
1751-
# print("-->", slice_)
1752-
# print("starts:", starts)
1753-
# print("stops:", stops)
17541759

17551760
# Get the slice of each operand
17561761
for key, value in operands.items():
@@ -1877,11 +1882,14 @@ def slices_eval2( # noqa: C901
18771882
if orig_slice is not None:
18781883
if isinstance(out, np.ndarray):
18791884
out = np.squeeze(out)
1885+
if _slice_step:
1886+
out = out[orig_slice]
18801887
elif isinstance(out, blosc2.NDArray):
18811888
# It *seems* better to choose an automatic chunks and blocks for the output array
18821889
# out = out.slice(orig_slice, chunks=out.chunks, blocks=out.blocks)
1883-
# out = out.slice(orig_slice)
18841890
out = out.squeeze()
1891+
if _slice_step:
1892+
out = out.slice(orig_slice)
18851893
else:
18861894
raise ValueError("The output array is not a NumPy array or a NDArray")
18871895

@@ -2306,9 +2314,9 @@ def chunked_eval( # noqa: C901
23062314
if getitem and (where is None or len(where) == 2) and not callable(expression):
23072315
# If we are using getitem, we can still use some optimizations
23082316
return slices_eval_getitem(expression, operands, _slice=item, **kwargs)
2309-
# The next is an attempt to reduce memory consumption in a general way, but not working yet
2310-
# return slices_eval2(expression, operands, getitem=getitem, _slice=item, **kwargs)
2311-
return slices_eval(expression, operands, getitem=getitem, _slice=item, **kwargs)
2317+
# return slices_eval(expression, operands, getitem=getitem, _slice=item, **kwargs)
2318+
# The next is an improved version of slices_eval that consumes less memory
2319+
return slices_eval2(expression, operands, getitem=getitem, _slice=item, **kwargs)
23122320

23132321
if fast_path:
23142322
if getitem:
@@ -2322,7 +2330,9 @@ def chunked_eval( # noqa: C901
23222330
# a blosc2.NDArray
23232331
return fast_eval(expression, operands, getitem=False, **kwargs)
23242332

2325-
res = slices_eval(expression, operands, getitem=getitem, _slice=item, **kwargs)
2333+
# res = slices_eval(expression, operands, getitem=getitem, _slice=item, **kwargs)
2334+
# The next is an improved version of slices_eval that consumes less memory
2335+
res = slices_eval2(expression, operands, getitem=getitem, _slice=item, **kwargs)
23262336

23272337
finally:
23282338
# Deactivate cache for NDField instances

0 commit comments

Comments
 (0)