diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eef1cc97da2..1a448dd5746 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -76,5 +76,3 @@ repos: rev: v1 hooks: - id: typos - # https://github.com/crate-ci/typos/issues/347 - pass_filenames: false diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 8e4458fb88f..cee3488db5b 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -26,6 +26,7 @@ get_valid_numpy_dtype, is_duck_array, is_duck_dask_array, + is_full_slice, is_scalar, is_valid_numpy_dtype, to_0d_array, @@ -41,6 +42,9 @@ from xarray.namedarray._typing import _Shape, duckarray from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint +BasicIndexerType = int | np.integer | slice +OuterIndexerType = BasicIndexerType | np.ndarray[Any, np.dtype[np.generic]] + @dataclass class IndexSelResult: @@ -298,18 +302,62 @@ def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice: return slice(start, stop, step) -def _index_indexer_1d(old_indexer, applied_indexer, size: int): - if isinstance(applied_indexer, slice) and applied_indexer == slice(None): +def _index_indexer_1d( + old_indexer: OuterIndexerType | MultipleSlices, + applied_indexer: OuterIndexerType | MultipleSlices, + size: int, +) -> OuterIndexerType | MultipleSlices: + if is_full_slice(applied_indexer): # shortcut for the usual case return old_indexer + if is_full_slice(old_indexer): + return applied_indexer + + indexer: OuterIndexerType | MultipleSlices if isinstance(old_indexer, slice): if isinstance(applied_indexer, slice): indexer = slice_slice(old_indexer, applied_indexer, size) + elif isinstance(applied_indexer, MultipleSlices): + indexer = MultipleSlices.from_iterable( + slice_slice(old_indexer, s, size) for s in applied_indexer.slices + ).merge_slices() elif isinstance(applied_indexer, integer_types): - indexer = range(*old_indexer.indices(size))[applied_indexer] # type: ignore[assignment] + indexer = 
range(*old_indexer.indices(size))[applied_indexer] else: indexer = _expand_slice(old_indexer, size)[applied_indexer] + elif isinstance(old_indexer, MultipleSlices): + if isinstance(applied_indexer, slice): + indexer = MultipleSlices.from_iterable( + slice_slice(s, applied_indexer, size) for s in old_indexer.slices + ).merge_slices() + elif isinstance(applied_indexer, MultipleSlices): + new_slices = ( + slice_slice(slice_a, slice_b, size) + for slice_a in old_indexer.slices + for slice_b in applied_indexer.slices + ) + + indexer = MultipleSlices.from_iterable( + s for s in new_slices if len(range(*s.indices(size))) > 0 + ) + elif isinstance(applied_indexer, integer_types): + # the integer indexes the concatenated result of the slices, so + # map it back to a position along the original (unindexed) axis + expanded = np.concatenate( + [_expand_slice(s, size) for s in old_indexer.slices] + ) + indexer = int(expanded[applied_indexer]) + else: + parts = [ + _expand_slice(s, size)[applied_indexer] for s in old_indexer.slices + ] + indexer = np.concatenate(parts) + elif isinstance(applied_indexer, MultipleSlices): + old_indexer = cast(np.ndarray, old_indexer) + parts = [old_indexer[s] for s in applied_indexer.slices] + indexer = np.concatenate(parts) else: + old_indexer = cast(np.ndarray, old_indexer) indexer = old_indexer[applied_indexer] return indexer @@ -355,6 +403,10 @@ def as_integer_slice(value: slice) -> slice: return slice(start, stop, step) +def as_integer_multi_slice(value: MultipleSlices) -> MultipleSlices: + return MultipleSlices.from_iterable(as_integer_slice(s) for s in value.slices) + + class IndexCallable: """Provide getitem and setitem syntax for callable objects.""" @@ -406,6 +458,11 @@ def __init__(self, key: tuple[int | np.integer | slice, ...]): super().__init__(tuple(new_key)) +outer_indexer_key_type = ( + int | np.integer | slice | np.ndarray[Any, np.dtype[np.generic]] +) + + class OuterIndexer(ExplicitIndexer): """Tuple for outer/orthogonal indexing. 
@@ -417,12 +474,7 @@ class OuterIndexer(ExplicitIndexer): __slots__ = () - def __init__( - self, - key: tuple[ - int | np.integer | slice | np.ndarray[Any, np.dtype[np.generic]], ... - ], - ): + def __init__(self, key: tuple[outer_indexer_key_type | MultipleSlices, ...]): if not isinstance(key, tuple): raise TypeError(f"key must be a tuple: {key!r}") @@ -432,6 +484,8 @@ def __init__( k = int(k) elif isinstance(k, slice): k = as_integer_slice(k) + elif isinstance(k, MultipleSlices): + k = as_integer_multi_slice(k) elif is_duck_array(k): if not np.issubdtype(k.dtype, np.integer): raise TypeError( @@ -591,6 +645,81 @@ def __getitem__(self, key: Any): return result +class MultipleSlices: + __slots__ = ("_slices",) + + def __init__(self, *slices: slice): + if any(not isinstance(s, slice) for s in slices): + raise ValueError("Can only wrap slice objects.") + + if any(is_full_slice(s) for s in slices) and len(slices) > 1: + raise ValueError("Full slices can only be wrapped on their own.") + + self._slices = list(slices) + + def __repr__(self): + return f"MultipleSlices({', '.join(repr(s) for s in self.slices)})" + + @classmethod + def _construct_direct(cls, slices: list[slice]) -> Self: + instance = cls.__new__(cls) + instance._slices = slices + return instance + + @classmethod + def from_iterable(cls, slices: Iterable[slice]) -> Self: + slices_ = list(slices) + if not slices_: + raise ValueError("need at least one slice object") + + return cls._construct_direct(slices_) + + @property + def slices(self): + return self._slices + + def merge_slices(self) -> Self: + new_slices = list(self._slices[:1]) + previous_index = 0 + for current in self._slices[1:]: + previous = new_slices[previous_index] + + if ( + previous.step == current.step + # `None` is treated as `1` for position slices + or (previous.step in (None, 1) and current.step in (None, 1)) + ) and previous.stop == current.start: + new_slices[previous_index] = slice( + previous.start, current.stop, previous.step + ) + 
continue + elif ( + current.start is None and current.stop == 0 + ) or current.start == current.stop: + # length 0 slice + continue + + new_slices.append(current) + previous_index += 1 + + return self._construct_direct(new_slices) + + +def decompose_by_multiple_slices( + indexer: tuple[outer_indexer_key_type | MultipleSlices, ...], +) -> tuple[tuple[slice | MultipleSlices, ...], tuple[outer_indexer_key_type, ...]]: + others = tuple( + k if not isinstance(k, MultipleSlices) else slice(None) for k in indexer + ) + multiple_slices = tuple( + k if isinstance(k, MultipleSlices) else slice(None) + for k in indexer + if not isinstance(k, integer_types) + ) + + return multiple_slices, others + + class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin): """Wrap an array to make basic and outer indexing lazy.""" @@ -623,11 +752,13 @@ def __init__(self, array: Any, key: ExplicitIndexer | None = None): shape += (len(range(*k.indices(size))),) elif isinstance(k, np.ndarray): shape += (k.size,) + elif isinstance(k, MultipleSlices): + shape += (sum(len(range(*s.indices(size))) for s in k.slices),) self._shape = shape def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer: iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim)) - full_key = [] + full_key: list[OuterIndexerType | MultipleSlices] = [] for size, k in zip(self.array.shape, self.key.tuple, strict=True): if isinstance(k, integer_types): full_key.append(k) @@ -636,7 +767,7 @@ def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer: full_key_tuple = tuple(full_key) if all(isinstance(k, integer_types + (slice,)) for k in full_key_tuple): - return BasicIndexer(full_key_tuple) + return BasicIndexer(cast(tuple[BasicIndexerType, ...], full_key_tuple)) return OuterIndexer(full_key_tuple) @property @@ -681,6 +812,10 @@ def _vindex_set(self, key: VectorizedIndexer, value: Any) -> None: def _oindex_set(self, key: OuterIndexer, value: Any) -> None: full_key = 
self._updated_key(key) + if any(isinstance(k, MultipleSlices) for k in key.tuple): + raise NotImplementedError( + "Assignment is not supported for a list of slices, yet." + ) self.array.oindex[full_key] = value def __setitem__(self, key: BasicIndexer, value: Any) -> None: @@ -1543,8 +1678,29 @@ def transpose(self, order): return self.array.transpose(order) def _oindex_get(self, indexer: OuterIndexer): - key = _outer_to_numpy_indexer(indexer, self.array.shape) - return self.array[key] + multi_slices, others = decompose_by_multiple_slices(indexer.tuple) + others_ = _outer_to_numpy_indexer(OuterIndexer(others), self.array.shape) + + if all(is_full_slice(s) for s in multi_slices): + return self.array[others_] + + # apply other indexers + if any(not is_full_slice(s) for s in others_): + value = self.array[others_ + (Ellipsis,)] + else: + value = self.array + + # apply the multi-slices + for axis, subkey in enumerate(multi_slices): + if not isinstance(subkey, MultipleSlices): + continue + + parts = [ + value[(slice(None),) * axis + (slice_,)] for slice_ in subkey.slices + ] + value = np.concatenate(parts, axis=axis) + + return value def _vindex_get(self, indexer: VectorizedIndexer): _assert_not_chunked_indexer(indexer.tuple) @@ -1617,12 +1773,25 @@ def __init__(self, array): ) self.array = array + def _oindex_get_impl(self, value, axis, subkey): + if not isinstance(subkey, MultipleSlices): + return value[(slice(None),) * axis + (subkey, Ellipsis)] + + xp = value.__array_namespace__() + return xp.concat( + [ + value[(slice(None),) * axis + (slice_, Ellipsis)] + for slice_ in subkey._slices + ], + axis=axis, + ) + def _oindex_get(self, indexer: OuterIndexer): # manual orthogonal indexing (implemented like DaskIndexingAdapter) key = indexer.tuple value = self.array for axis, subkey in reversed(list(enumerate(key))): - value = value[(slice(None),) * axis + (subkey, Ellipsis)] + value = self._oindex_get_impl(value, axis, subkey) return value def _vindex_get(self, indexer: 
VectorizedIndexer): @@ -1684,11 +1853,21 @@ def _oindex_get(self, indexer: OuterIndexer): key = indexer.tuple try: return self.array[key] - except NotImplementedError: + except (NotImplementedError, TypeError): # manual orthogonal indexing value = self.array for axis, subkey in reversed(list(enumerate(key))): - value = value[(slice(None),) * axis + (subkey,)] + if not isinstance(subkey, MultipleSlices): + value = value[(slice(None),) * axis + (subkey, Ellipsis)] + continue + + value = np.concatenate( + [ + value[(slice(None),) * axis + (slice_, Ellipsis)] + for slice_ in subkey._slices + ], + axis=axis, + ) return value def _vindex_get(self, indexer: VectorizedIndexer): @@ -1868,7 +2047,13 @@ def _index_get( return getattr(indexable, func_name)(indexer) # otherwise index the pandas index then re-wrap or convert the result - result = self.array[key] + if not isinstance(key, MultipleSlices): + result = self.array[key] + else: + result = None + for s in key.slices: + subset = self.array[s] + result = result.union(subset) if result is not None else subset if isinstance(result, pd.Index): return type(self)(result, dtype=self.dtype) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index bcc2ca4e460..77a59a7db40 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -24,6 +24,7 @@ from xarray.core.indexing import ( BasicIndexer, CoordinateTransformIndexingAdapter, + MultipleSlices, OuterIndexer, PandasIndexingAdapter, VectorizedIndexer, @@ -708,7 +709,9 @@ def _broadcast_indexes_outer(self, key): for k in key: if isinstance(k, Variable): k = k.data - if not isinstance(k, BASIC_INDEXING_TYPES): + if not isinstance(k, BASIC_INDEXING_TYPES) and not isinstance( + k, MultipleSlices + ): if not is_duck_array(k): k = np.asarray(k) if k.size == 0: diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 6dd75b58c6a..07af0dd12a9 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -274,6 +274,70 @@ 
def test_read_only_view(self) -> None: arr.loc[0, 0, 0] = 999 +class TestMultipleSlices: + def test_init(self): + slices = [slice(None, 2), slice(3, None)] + actual = indexing.MultipleSlices(*slices) + + assert isinstance(actual, indexing.MultipleSlices) + assert actual._slices == slices and actual._slices is not slices + + def test_init_error(self): + slices = [1, slice(2, 3), "a"] + with pytest.raises(ValueError, match="slice objects"): + indexing.MultipleSlices(*slices) + + with pytest.raises(ValueError, match="Full slices"): + indexing.MultipleSlices(slice(None, 2), slice(None)) + + def test_construct_direct(self): + slices = [slice(None, 2), slice(3, None)] + actual = indexing.MultipleSlices._construct_direct(slices) + + assert isinstance(actual, indexing.MultipleSlices) + assert actual._slices is slices + + @pytest.mark.parametrize( + ["iterable", "expected"], + ( + ((slice(i, j) for i, j in [(0, 3), (4, 6)]), [slice(0, 3), slice(4, 6)]), + ( + (slice(None, 2), slice(3, 4), slice(4, 5)), + [slice(None, 2), slice(3, 4), slice(4, 5)], + ), + ), + ) + def test_from_iterable(self, iterable, expected): + actual = indexing.MultipleSlices.from_iterable(iterable) + assert isinstance(actual, indexing.MultipleSlices) + assert actual._slices == expected + + @pytest.mark.parametrize( + ["slices", "expected_slices"], + ( + ([slice(None, 3), slice(3, None)], [slice(None)]), + ([slice(None, 2), slice(2, 4), slice(4, 10)], [slice(None, 10)]), + ( + [slice(None, 2, 1), slice(2, 6, 2), slice(6, 10, 2)], + [slice(None, 2, 1), slice(2, 10, 2)], + ), + ( + [slice(None, 2), slice(2, 5, 1), slice(5, None, 2)], + [slice(None, 5), slice(5, None, 2)], + ), + ( + [slice(None, 2), slice(0), slice(2, 5)], + [slice(None, 5)], + ), + ), + ) + def test_merge_slices(self, slices, expected_slices): + multi_slice = indexing.MultipleSlices.from_iterable(slices) + actual = multi_slice.merge_slices() + + assert actual._slices == expected_slices + + class TestLazyArray: def test_slice_slice(self) 
-> None: arr = ReturnItem() @@ -313,7 +377,17 @@ def test_lazily_indexed_array(self) -> None: v_lazy = Variable(["i", "j", "k"], lazy) arr = ReturnItem() # test orthogonally applied indexers - indexers = [arr[:], 0, -2, arr[:3], [0, 1, 2, 3], [0], np.arange(10) < 5] + indexers = [ + arr[:], + 0, + -2, + arr[:3], + [0, 1, 2, 3], + [0], + np.arange(10) < 5, + indexing.MultipleSlices(slice(0, 3), slice(5, 7)), + indexing.MultipleSlices(slice(None, 5), slice(7, None, 2)), + ] for i in indexers: for j in indexers: for k in indexers: @@ -974,6 +1048,27 @@ def test_indexing_1d_object_array() -> None: assert [actual.data.item()] == [expected.data.item()] +@pytest.mark.parametrize( + "key", + ( + indexing.MultipleSlices(slice(1, 3), slice(7, 4, -1)), + indexing.MultipleSlices(slice(None, 2), slice(5, None)), + ), +) +def test_indexing_index_multi_slice(key) -> None: + indexer = indexing.OuterIndexer((key,)) + x = np.arange(20) + + pd_adapter = indexing.PandasIndexingAdapter(pd.Index(x)) + np_adapter = indexing.NumpyIndexingAdapter(x) + + actual = pd_adapter.oindex[indexer] + # indexes are sorted + expected = np.sort(np_adapter.oindex[indexer]) + + assert_array_equal(actual, expected) + + @requires_dask def test_indexing_dask_array() -> None: import dask.array @@ -1055,6 +1150,28 @@ def test_advanced_indexing_dask_array() -> None: assert_identical(expected, actual) +@requires_dask +@pytest.mark.parametrize( + "key", + ( + indexing.MultipleSlices(slice(1, 3), slice(7, 4, -1)), + indexing.MultipleSlices(slice(None, 2), slice(5, None)), + ), +) +def test_indexing_dask_multi_slice(key) -> None: + da = DataArray( + np.ones(10 * 3 * 3).reshape((10, 3, 3)), + dims=("time", "x", "y"), + ) + chunked = da.chunk(dict(time=-1, x=1, y=1)) + + with raise_if_dask_computes(): + actual = chunked.isel(time=key) + + expected = da.isel(time=key) + assert_identical(actual, expected) + + def test_backend_indexing_non_numpy() -> None: """This model indexing of a Zarr store that reads to GPU 
memory.""" array = DuckArrayWrapper(np.array([1, 2, 3]))