diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index c98175578f8..7d7f9335cb2 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -28,6 +28,7 @@
     is_allowed_extension_array_dtype,
     is_duck_array,
     is_duck_dask_array,
+    is_full_slice,
     is_scalar,
     is_valid_numpy_dtype,
     to_0d_array,
@@ -43,6 +44,9 @@
     from xarray.namedarray._typing import _Shape, duckarray
     from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
 
+BasicIndexerType = int | np.integer | slice
+OuterIndexerType = BasicIndexerType | np.ndarray[Any, np.dtype[np.integer]]
+
 
 @dataclass
 class IndexSelResult:
@@ -300,19 +304,113 @@ def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice:
     return slice(start, stop, step)
 
 
-def _index_indexer_1d(old_indexer, applied_indexer, size: int):
-    if isinstance(applied_indexer, slice) and applied_indexer == slice(None):
+def normalize_array(
+    array: np.ndarray[Any, np.dtype[np.integer]], size: int
+) -> np.ndarray[Any, np.dtype[np.integer]]:
+    """
+    Convert negative (end-relative) indices to their non-negative equivalents.
+
+    Negative entries are shifted up by ``size``. No bounds checking is done, so
+    entries smaller than ``-size`` remain negative.
+
+    Examples
+    --------
+    >>> normalize_array(np.array([-1, -2, -3, -4]), 10)
+    array([9, 8, 7, 6])
+    >>> normalize_array(np.array([-5, 3, 5, -1, 8]), 12)
+    array([ 7,  3,  5, 11,  8])
+    """
+    if np.issubdtype(array.dtype, np.unsignedinteger):
+        return array
+
+    return np.where(array >= 0, array, array + size)
+
+
+def slice_slice_by_array(
+    old_slice: slice,
+    array: np.ndarray[Any, np.dtype[np.integer]],
+    size: int,
+) -> np.ndarray[Any, np.dtype[np.integer]]:
+    """Given a slice and the size of the dimension to which it will be applied,
+    index it with an array to return a new array equivalent to applying
+    the slice and then the array sequentially.
+
+    Parameters
+    ----------
+    old_slice : slice
+        Slice that is applied first.
+    array : np.ndarray of int
+        Positions, possibly negative, into the result of ``old_slice``.
+    size : int
+        Length of the dimension that ``old_slice`` is applied to.
+
+    Returns
+    -------
+    np.ndarray of int
+        Non-negative positions into the original dimension.
+
+    Raises
+    ------
+    IndexError
+        If any entry of ``array`` is out of bounds for the sliced dimension.
+
+    Examples
+    --------
+    >>> slice_slice_by_array(slice(2, 10), np.array([1, 3, 5]), 12)
+    array([3, 5, 7])
+    >>> slice_slice_by_array(slice(1, None, 2), np.array([1, 3, 7, 8]), 20)
+    array([ 3,  7, 15, 17])
+    >>> slice_slice_by_array(slice(None, None, -1), np.array([2, 4, 7]), 20)
+    array([17, 15, 12])
+    >>> slice_slice_by_array(slice(None, None, -1), np.array([-1, 0]), 20)
+    array([ 0, 19])
+    """
+    # slice.indices() clips the slice to ``size`` and returns exactly the
+    # (start, stop, step) that ``range`` needs to enumerate the selected positions
+    start, stop, step = old_slice.indices(size)
+    size_after_slice = len(range(start, stop, step))
+
+    normalized_array = normalize_array(array, size_after_slice)
+
+    # validate against the length of the *sliced* dimension: a position can be
+    # out of bounds for the slice while still landing inside the parent dimension
+    if np.any((normalized_array < 0) | (normalized_array >= size_after_slice)):
+        raise IndexError("indices out of bounds")  # TODO: more helpful error message
+
+    return normalized_array * step + start
+
+
+def _index_indexer_1d(
+    old_indexer: OuterIndexerType,
+    applied_indexer: OuterIndexerType,
+    size: int,
+) -> OuterIndexerType:
+    """Combine two successive 1D indexers into a single equivalent indexer.
+
+    ``old_indexer`` addresses a dimension of length ``size``;
+    ``applied_indexer`` addresses the result of that first indexing step.
+    """
+    if is_full_slice(applied_indexer):
         # shortcut for the usual case
         return old_indexer
+    if is_full_slice(old_indexer):
+        # shortcut for full slices
+        return applied_indexer
+
+    indexer: OuterIndexerType
     if isinstance(old_indexer, slice):
         if isinstance(applied_indexer, slice):
             indexer = slice_slice(old_indexer, applied_indexer, size)
         elif isinstance(applied_indexer, integer_types):
-            indexer = range(*old_indexer.indices(size))[applied_indexer]  # type: ignore[assignment]
+            indexer = range(*old_indexer.indices(size))[applied_indexer]
         else:
-            indexer = _expand_slice(old_indexer, size)[applied_indexer]
-    else:
+            indexer = slice_slice_by_array(old_indexer, applied_indexer, size)
+    elif isinstance(old_indexer, np.ndarray):
         indexer = old_indexer[applied_indexer]
+    else:
+        # should be unreachable
+        raise ValueError("cannot index integers. Please open an issue.")
+
     return indexer
 
 
@@ -389,7 +487,7 @@ class BasicIndexer(ExplicitIndexer):
 
     __slots__ = ()
 
-    def __init__(self, key: tuple[int | np.integer | slice, ...]):
+    def __init__(self, key: tuple[BasicIndexerType, ...]):
         if not isinstance(key, tuple):
             raise TypeError(f"key must be a tuple: {key!r}")
 
@@ -421,9 +519,7 @@ class OuterIndexer(ExplicitIndexer):
 
     def __init__(
         self,
-        key: tuple[
-            int | np.integer | slice | np.ndarray[Any, np.dtype[np.generic]], ...
-        ],
+        key: tuple[BasicIndexerType | np.ndarray[Any, np.dtype[np.generic]], ...],
     ):
         if not isinstance(key, tuple):
             raise TypeError(f"key must be a tuple: {key!r}")
@@ -629,7 +725,8 @@ def __init__(self, array: Any, key: ExplicitIndexer | None = None):
 
     def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer:
         iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim))
-        full_key = []
+
+        full_key: list[OuterIndexerType] = []
         for size, k in zip(self.array.shape, self.key.tuple, strict=True):
             if isinstance(k, integer_types):
                 full_key.append(k)
@@ -638,7 +735,7 @@ def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer:
         full_key_tuple = tuple(full_key)
 
         if all(isinstance(k, integer_types + (slice,)) for k in full_key_tuple):
-            return BasicIndexer(full_key_tuple)
+            return BasicIndexer(cast(tuple[BasicIndexerType, ...], full_key_tuple))
         return OuterIndexer(full_key_tuple)
 
     @property
diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py
index 6dd75b58c6a..db4f6aaf0bd 100644
--- a/xarray/tests/test_indexing.py
+++ b/xarray/tests/test_indexing.py
@@ -305,6 +305,23 @@ def test_slice_slice(self) -> None:
                     actual = x[new_slice]
                     assert_array_equal(expected, actual)
 
+    @pytest.mark.parametrize(
+        ["old_slice", "array", "size"],
+        (
+            (slice(None, 8), np.arange(2, 6), 10),
+            (slice(2, None), np.arange(2, 6), 10),
+            (slice(1, 10, 2), np.arange(1, 4), 15),
+            (slice(10, None, -1), np.array([2, 5, 7]), 12),
+            (slice(2, None, 2), np.array([3, -2, 5, -1]), 13),
+            (slice(8, None), np.array([1, -2, 2, -1, -7]), 20),
+        ),
+    )
+    def test_slice_slice_by_array(self, old_slice, array, size) -> None:
+        """The combined indexer must select what slice-then-array selects."""
+        actual = indexing.slice_slice_by_array(old_slice, array, size)
+        expected = np.arange(size)[old_slice][array]
+        assert_array_equal(actual, expected)
+
     def test_lazily_indexed_array(self) -> None:
         original = np.random.rand(10, 20, 30)
         x = indexing.NumpyIndexingAdapter(original)