diff --git a/dpnp/dpnp_iface_functional.py b/dpnp/dpnp_iface_functional.py index d0c64ca7e7ff..71f2ef220332 100644 --- a/dpnp/dpnp_iface_functional.py +++ b/dpnp/dpnp_iface_functional.py @@ -37,7 +37,6 @@ """ -import numpy from dpctl.tensor._numpy_helper import ( normalize_axis_index, normalize_axis_tuple, @@ -151,8 +150,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): # compute indices for the iteration axes, and append a trailing ellipsis to # prevent 0d arrays decaying to scalars - # TODO: replace with dpnp.ndindex - inds = numpy.ndindex(inarr_view.shape[:-1]) + inds = dpnp.ndindex(inarr_view.shape[:-1]) inds = (ind + (Ellipsis,) for ind in inds) # invoke the function on the first item diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index fa4109acc683..9881df7c4c4a 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -64,6 +64,7 @@ "indices", "ix_", "mask_indices", + "ndindex", "nonzero", "place", "put", @@ -1057,6 +1058,77 @@ def mask_indices( return nonzero(a != 0) +# pylint: disable=invalid-name +# pylint: disable=too-few-public-methods +class ndindex: + """ + An N-dimensional iterator object to index arrays. + + Given the shape of an array, an :obj:`dpnp.ndindex` instance iterates over + the N-dimensional index of the array. At each iteration a tuple of indices + is returned, the last dimension is iterated over first. + + For full documentation refer to :obj:`numpy.ndindex`. + + Parameters + ---------- + shape : ints, or a single tuple of ints + The size of each dimension of the array can be passed as individual + parameters or as the elements of a tuple. + + See Also + -------- + :obj:`dpnp.ndenumerate` : Multidimensional index iterator. + :obj:`dpnp.flatiter` : Flat iterator object to iterate over arrays. + + Examples + -------- + >>> import dpnp as np + + Dimensions as individual arguments + + >>> for index in np.ndindex(3, 2, 1): + ... print(index) + (0, 0, 0) + (0, 1, 0) + (1, 0, 0) + (1, 1, 0) + (2, 0, 0) + (2, 1, 0) + + Same dimensions - but in a tuple ``(3, 2, 1)`` + + >>> for index in np.ndindex((3, 2, 1)): + ... print(index) + (0, 0, 0) + (0, 1, 0) + (1, 0, 0) + (1, 1, 0) + (2, 0, 0) + (2, 1, 0) + + """ + + def __init__(self, *shape): + self.ndindex_ = numpy.ndindex(*shape) + + def __iter__(self): + return self.ndindex_ + + def __next__(self): + """ + Standard iterator method, updates the index and returns the index tuple. + + Returns + ------- + val : tuple of ints + Returns a tuple containing the indices of the current iteration. + + """ + + return self.ndindex_.__next__() + + def nonzero(a): """ Return the indices of the elements that are non-zero. diff --git a/dpnp/dpnp_utils/dpnp_utils_pad.py b/dpnp/dpnp_utils/dpnp_utils_pad.py index cff48a69410c..4f71b9e393d3 100644 --- a/dpnp/dpnp_utils/dpnp_utils_pad.py +++ b/dpnp/dpnp_utils/dpnp_utils_pad.py @@ -642,8 +642,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs): # compute indices for the iteration axes, and append a trailing # ellipsis to prevent 0d arrays decaying to scalars - # TODO: replace with dpnp.ndindex when implemented - inds = numpy.ndindex(view.shape[:-1]) + inds = dpnp.ndindex(view.shape[:-1]) inds = (ind + (Ellipsis,) for ind in inds) for ind in inds: function(view[ind], pad_width[axis], axis, kwargs) diff --git a/dpnp/tests/test_indexing.py b/dpnp/tests/test_indexing.py index 13266470e8c7..a3b7ddf5c349 100644 --- a/dpnp/tests/test_indexing.py +++ b/dpnp/tests/test_indexing.py @@ -335,6 +335,32 @@ def test_ix_error(self, xp, shape): assert_raises(ValueError, xp.ix_, xp.ones(shape)) +@pytest.mark.parametrize( + "shape", [[1, 2, 3], [(1, 2, 3)], [(3,)], [3], [], [()], [0]] +) +class TestNdindex: + def test_basic(self, shape): + result = dpnp.ndindex(*shape) + expected = numpy.ndindex(*shape) + + for x, y in zip(result, expected): + assert x == y + + def test_next(self, shape): + dind = dpnp.ndindex(*shape) + nind = numpy.ndindex(*shape) + + while True: + try: + ditem = next(dind) + except StopIteration: + assert_raises(StopIteration, next, nind) + break # both reach ends + else: + nitem = next(nind) + assert ditem == nitem + + class TestNonzero: @pytest.mark.parametrize("list_val", [[], [0], [1]]) def test_trivial(self, list_val):