Skip to content

Commit 808e891

Browse files
author
Luke Shaw
committed
Fix indexing for lazy expressions, and allow use of None in getitem
1 parent e1af9b4 commit 808e891

File tree

3 files changed

+24
-13
lines changed

3 files changed

+24
-13
lines changed

src/blosc2/lazyexpr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2440,16 +2440,16 @@ def _compute_expr(self, item, kwargs): # noqa: C901
24402440
if len(self._where_args) == 1:
24412441
# We have a single argument
24422442
where_x = self._where_args["_where_x"]
2443-
return where_x[:][lazy_expr]
2443+
return (where_x[:][lazy_expr])[item]
24442444
if len(self._where_args) == 2:
24452445
# We have two arguments
24462446
where_x = self._where_args["_where_x"]
24472447
where_y = self._where_args["_where_y"]
2448-
return np.where(lazy_expr, where_x, where_y)
2448+
return np.where(lazy_expr, where_x, where_y)[item]
24492449
if hasattr(self, "_output"):
24502450
# This is not exactly optimized, but it works for now
2451-
self._output[:] = lazy_expr
2452-
return lazy_expr
2451+
self._output[:] = lazy_expr[item]
2452+
return lazy_expr[item]
24532453

24542454
return chunked_eval(lazy_expr.expression, lazy_expr.operands, item, **kwargs)
24552455

src/blosc2/ndarray.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,15 @@ def make_key_hashable(key):
6161
else:
6262
return key
6363

64-
6564
def process_key(key, shape):
66-
if key is None:
67-
key = tuple(slice(None) for _ in range(len(shape)))
6865
key = ndindex.ndindex(key).expand(shape).raw
69-
mask = tuple(isinstance(k, int) for k in key)
70-
key = tuple(k if isinstance(k, slice) else slice(k, k + 1, None) for k in key)
71-
return key, mask
66+
key_ = ()
67+
mask = () #get integer indices where dimension collapses
68+
for k in key: #handle multiple Nones in key
69+
if k is not None:
70+
key_ += (k if isinstance(k, slice) else slice(k, k + 1, None),)
71+
mask += (isinstance(k, int),)
72+
return key_, mask
7273

7374

7475
def get_ndarray_start_stop(ndim, key, shape):
@@ -1469,6 +1470,7 @@ def __getitem__( # noqa: C901
14691470
[3.3333, 3.3333, 3.3333, 3.3333, 3.3333]])
14701471
"""
14711472
# First try some fast paths for common cases
1473+
newaxes = None
14721474
if isinstance(key, np.integer):
14731475
# Massage the key to a tuple and go the fast path
14741476
key_ = (slice(key, key + 1), *(slice(None),) * (self.ndim - 1))
@@ -1519,6 +1521,8 @@ def __getitem__( # noqa: C901
15191521
start, stop, step = get_ndarray_start_stop(self.ndim, key_, self.shape)
15201522
shape = np.array([sp - st for st, sp in zip(start, stop, strict=True)])
15211523
shape = tuple(shape[[not m for m in mask]])
1524+
# Add new axes if necessary
1525+
newaxes = tuple(i for i, k in enumerate(key) if k is None)
15221526

15231527
# Create the array to store the result
15241528
arr = np.empty(shape, dtype=self.dtype)
@@ -1533,9 +1537,9 @@ def __getitem__( # noqa: C901
15331537
self._last_read.clear()
15341538
inmutable_key = make_key_hashable(key)
15351539
self._last_read[inmutable_key] = nparr
1536-
1537-
return nparr
1538-
1540+
1541+
return nparr.expand_dims(newaxes) if newaxes is not None else nparr
1542+
15391543
def __setitem__(self, key: int | slice | Sequence[slice], value: object):
15401544
"""Set a slice of the array.
15411545

tests/ndarray/test_reductions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,10 @@ def test_save_constructor_reduce2(shape, disk, compute):
418418
blosc2.remove_urlpath(urlpath_a)
419419
blosc2.remove_urlpath(urlpath_b)
420420
blosc2.remove_urlpath("out.b2nd")
421+
422+
def test_reduction_index():
423+
shape = (20, 20)
424+
a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape)
425+
expr = blosc2.sum(a, axis=0)
426+
assert expr[:10].shape == (10,)
427+
assert expr[0].shape == ()

0 commit comments

Comments
 (0)