Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions dpnp/dpnp_iface_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"""


import numpy
from dpctl.tensor._numpy_helper import (
normalize_axis_index,
normalize_axis_tuple,
Expand Down Expand Up @@ -151,8 +150,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):

# compute indices for the iteration axes, and append a trailing ellipsis to
# prevent 0d arrays decaying to scalars
# TODO: replace with dpnp.ndindex
inds = numpy.ndindex(inarr_view.shape[:-1])
inds = dpnp.ndindex(inarr_view.shape[:-1])
inds = (ind + (Ellipsis,) for ind in inds)

# invoke the function on the first item
Expand Down
72 changes: 72 additions & 0 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"indices",
"ix_",
"mask_indices",
"ndindex",
"nonzero",
"place",
"put",
Expand Down Expand Up @@ -1057,6 +1058,77 @@ def mask_indices(
return nonzero(a != 0)


# pylint: disable=invalid-name
# pylint: disable=too-few-public-methods
class ndindex:
"""
An N-dimensional iterator object to index arrays.

Given the shape of an array, an :obj:`dpnp.ndindex` instance iterates over
the N-dimensional index of the array. At each iteration a tuple of indices
is returned, the last dimension is iterated over first.

For full documentation refer to :obj:`numpy.ndindex`.

Parameters
----------
shape : ints, or a single tuple of ints
The size of each dimension of the array can be passed as individual
parameters or as the elements of a tuple.

See Also
--------
:obj:`dpnp.ndenumerate` : Multidimensional index iterator.
:obj:`dpnp.flatiter` : Flat iterator object to iterate over arrays.

Examples
--------
>>> import dpnp as np

Dimensions as individual arguments

>>> for index in np.ndindex(3, 2, 1):
... print(index)
(0, 0, 0)
(0, 1, 0)
(1, 0, 0)
(1, 1, 0)
(2, 0, 0)
(2, 1, 0)

Same dimensions - but in a tuple ``(3, 2, 1)``

>>> for index in np.ndindex((3, 2, 1)):
... print(index)
(0, 0, 0)
(0, 1, 0)
(1, 0, 0)
(1, 1, 0)
(2, 0, 0)
(2, 1, 0)

"""

def __init__(self, *shape):
self.ndindex_ = numpy.ndindex(*shape)

def __iter__(self):
return self.ndindex_

def __next__(self):
"""
Standard iterator method, updates the index and returns the index tuple.

Returns
-------
val : tuple of ints
Returns a tuple containing the indices of the current iteration.

"""

return self.ndindex_.__next__()


def nonzero(a):
"""
Return the indices of the elements that are non-zero.
Expand Down
3 changes: 1 addition & 2 deletions dpnp/dpnp_utils/dpnp_utils_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,8 +642,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):

# compute indices for the iteration axes, and append a trailing
# ellipsis to prevent 0d arrays decaying to scalars
# TODO: replace with dpnp.ndindex when implemented
inds = numpy.ndindex(view.shape[:-1])
inds = dpnp.ndindex(view.shape[:-1])
inds = (ind + (Ellipsis,) for ind in inds)
for ind in inds:
function(view[ind], pad_width[axis], axis, kwargs)
Expand Down
12 changes: 12 additions & 0 deletions dpnp/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,18 @@ def test_ix_error(self, xp, shape):
assert_raises(ValueError, xp.ix_, xp.ones(shape))


class TestNdindex:
@pytest.mark.parametrize(
"shape", [[1, 2, 3], [(1, 2, 3)], [(3,)], [3], [], [()], [0]]
)
def test_basic(self, shape):
result = dpnp.ndindex(*shape)
expected = numpy.ndindex(*shape)

for x, y in zip(result, expected):
assert x == y


class TestNonzero:
@pytest.mark.parametrize("list_val", [[], [0], [1]])
def test_trivial(self, list_val):
Expand Down
Loading