Skip to content

Commit 6d73f01

Browse files
keewisdcherian
andauthored
slicing a slice with an array without expanding the slice (#10580)
* use `is_full_slice` * don't try to expand a full slice * slice a slice with an array without materializing the slice * formatting * type hints * check that the new algorithm works properly * compare to `size` instead * doctests * type hints * stricter type hints * check against `numpy` Co-authored-by: Deepak Cherian <[email protected]> * remove the old parametrized expected values * fix type hints * shortcut for existing full slices * move the definition of the new types out of the type checking block * also support negative array values --------- Co-authored-by: Deepak Cherian <[email protected]>
1 parent 69316e5 commit 6d73f01

File tree

2 files changed

+94
-11
lines changed

2 files changed

+94
-11
lines changed

xarray/core/indexing.py

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
is_allowed_extension_array_dtype,
2929
is_duck_array,
3030
is_duck_dask_array,
31+
is_full_slice,
3132
is_scalar,
3233
is_valid_numpy_dtype,
3334
to_0d_array,
@@ -43,6 +44,9 @@
4344
from xarray.namedarray._typing import _Shape, duckarray
4445
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
4546

47+
BasicIndexerType = int | np.integer | slice
48+
OuterIndexerType = BasicIndexerType | np.ndarray[Any, np.dtype[np.integer]]
49+
4650

4751
@dataclass
4852
class IndexSelResult:
@@ -300,19 +304,83 @@ def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice:
300304
return slice(start, stop, step)
301305

302306

303-
def _index_indexer_1d(old_indexer, applied_indexer, size: int):
304-
if isinstance(applied_indexer, slice) and applied_indexer == slice(None):
307+
def normalize_array(
308+
array: np.ndarray[Any, np.dtype[np.integer]], size: int
309+
) -> np.ndarray[Any, np.dtype[np.integer]]:
310+
"""
311+
Ensure that the given array only contains positive values.
312+
313+
Examples
314+
--------
315+
>>> normalize_array(np.array([-1, -2, -3, -4]), 10)
316+
array([9, 8, 7, 6])
317+
>>> normalize_array(np.array([-5, 3, 5, -1, 8]), 12)
318+
array([ 7, 3, 5, 11, 8])
319+
"""
320+
if np.issubdtype(array.dtype, np.unsignedinteger):
321+
return array
322+
323+
return np.where(array >= 0, array, array + size)
324+
325+
326+
def slice_slice_by_array(
327+
old_slice: slice,
328+
array: np.ndarray[Any, np.dtype[np.integer]],
329+
size: int,
330+
) -> np.ndarray[Any, np.dtype[np.integer]]:
331+
"""Given a slice and the size of the dimension to which it will be applied,
332+
index it with an array to return a new array equivalent to applying
333+
the slices sequentially
334+
335+
Examples
336+
--------
337+
>>> slice_slice_by_array(slice(2, 10), np.array([1, 3, 5]), 12)
338+
array([3, 5, 7])
339+
>>> slice_slice_by_array(slice(1, None, 2), np.array([1, 3, 7, 8]), 20)
340+
array([ 3, 7, 15, 17])
341+
>>> slice_slice_by_array(slice(None, None, -1), np.array([2, 4, 7]), 20)
342+
array([17, 15, 12])
343+
"""
344+
# to get a concrete slice, limited to the size of the array
345+
normalized_slice = normalize_slice(old_slice, size)
346+
347+
size_after_slice = len(range(*normalized_slice.indices(size)))
348+
normalized_array = normalize_array(array, size_after_slice)
349+
350+
new_indexer = normalized_array * normalized_slice.step + normalized_slice.start
351+
352+
if np.any(new_indexer >= size):
353+
raise IndexError("indices out of bounds") # TODO: more helpful error message
354+
355+
return new_indexer
356+
357+
358+
def _index_indexer_1d(
359+
old_indexer: OuterIndexerType,
360+
applied_indexer: OuterIndexerType,
361+
size: int,
362+
) -> OuterIndexerType:
363+
if is_full_slice(applied_indexer):
305364
# shortcut for the usual case
306365
return old_indexer
366+
if is_full_slice(old_indexer):
367+
# shortcut for full slices
368+
return applied_indexer
369+
370+
indexer: OuterIndexerType
307371
if isinstance(old_indexer, slice):
308372
if isinstance(applied_indexer, slice):
309373
indexer = slice_slice(old_indexer, applied_indexer, size)
310374
elif isinstance(applied_indexer, integer_types):
311-
indexer = range(*old_indexer.indices(size))[applied_indexer] # type: ignore[assignment]
375+
indexer = range(*old_indexer.indices(size))[applied_indexer]
312376
else:
313-
indexer = _expand_slice(old_indexer, size)[applied_indexer]
314-
else:
377+
indexer = slice_slice_by_array(old_indexer, applied_indexer, size)
378+
elif isinstance(old_indexer, np.ndarray):
315379
indexer = old_indexer[applied_indexer]
380+
else:
381+
# should be unreachable
382+
raise ValueError("cannot index integers. Please open an issuec-")
383+
316384
return indexer
317385

318386

@@ -389,7 +457,7 @@ class BasicIndexer(ExplicitIndexer):
389457

390458
__slots__ = ()
391459

392-
def __init__(self, key: tuple[int | np.integer | slice, ...]):
460+
def __init__(self, key: tuple[BasicIndexerType, ...]):
393461
if not isinstance(key, tuple):
394462
raise TypeError(f"key must be a tuple: {key!r}")
395463

@@ -421,9 +489,7 @@ class OuterIndexer(ExplicitIndexer):
421489

422490
def __init__(
423491
self,
424-
key: tuple[
425-
int | np.integer | slice | np.ndarray[Any, np.dtype[np.generic]], ...
426-
],
492+
key: tuple[BasicIndexerType | np.ndarray[Any, np.dtype[np.generic]], ...],
427493
):
428494
if not isinstance(key, tuple):
429495
raise TypeError(f"key must be a tuple: {key!r}")
@@ -629,7 +695,8 @@ def __init__(self, array: Any, key: ExplicitIndexer | None = None):
629695

630696
def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer:
631697
iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim))
632-
full_key = []
698+
699+
full_key: list[OuterIndexerType] = []
633700
for size, k in zip(self.array.shape, self.key.tuple, strict=True):
634701
if isinstance(k, integer_types):
635702
full_key.append(k)
@@ -638,7 +705,7 @@ def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer:
638705
full_key_tuple = tuple(full_key)
639706

640707
if all(isinstance(k, integer_types + (slice,)) for k in full_key_tuple):
641-
return BasicIndexer(full_key_tuple)
708+
return BasicIndexer(cast(tuple[BasicIndexerType, ...], full_key_tuple))
642709
return OuterIndexer(full_key_tuple)
643710

644711
@property

xarray/tests/test_indexing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,22 @@ def test_slice_slice(self) -> None:
305305
actual = x[new_slice]
306306
assert_array_equal(expected, actual)
307307

308+
@pytest.mark.parametrize(
309+
["old_slice", "array", "size"],
310+
(
311+
(slice(None, 8), np.arange(2, 6), 10),
312+
(slice(2, None), np.arange(2, 6), 10),
313+
(slice(1, 10, 2), np.arange(1, 4), 15),
314+
(slice(10, None, -1), np.array([2, 5, 7]), 12),
315+
(slice(2, None, 2), np.array([3, -2, 5, -1]), 13),
316+
(slice(8, None), np.array([1, -2, 2, -1, -7]), 20),
317+
),
318+
)
319+
def test_slice_slice_by_array(self, old_slice, array, size):
320+
actual = indexing.slice_slice_by_array(old_slice, array, size)
321+
expected = np.arange(size)[old_slice][array]
322+
assert_array_equal(actual, expected)
323+
308324
def test_lazily_indexed_array(self) -> None:
309325
original = np.random.rand(10, 20, 30)
310326
x = indexing.NumpyIndexingAdapter(original)

0 commit comments

Comments
 (0)