Skip to content

Commit 2c4f3b5

Browse files
antonwolfyvtavana
andauthored
Leverage on dpctl.tensor implementation in dpnp.put_along_axis (#2134)
* Implement dpnp.put_along_axis through dpctl.tensor * Increase test coverage * Update dpnp/dpnp_iface_indexing.py Co-authored-by: vtavana <[email protected]> * Add empty lines per review comment --------- Co-authored-by: vtavana <[email protected]>
1 parent d1ff2e7 commit 2c4f3b5

File tree

2 files changed

+64
-73
lines changed

2 files changed

+64
-73
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 29 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -84,57 +84,6 @@
8484
]
8585

8686

87-
def _build_along_axis_index(a, ind, axis):
88-
"""
89-
Build a fancy index used by a family of `_along_axis` functions.
90-
91-
The fancy index consists of orthogonal arranges, with the
92-
requested index inserted at the right location.
93-
94-
The resulting index is going to be used inside `dpnp.put_along_axis`
95-
and `dpnp.take_along_axis` implementations.
96-
97-
"""
98-
99-
if not dpnp.issubdtype(ind.dtype, dpnp.integer):
100-
raise IndexError("`indices` must be an integer array")
101-
102-
# normalize array shape and input axis
103-
if axis is None:
104-
a_shape = (a.size,)
105-
axis = 0
106-
else:
107-
a_shape = a.shape
108-
axis = normalize_axis_index(axis, a.ndim)
109-
110-
if len(a_shape) != ind.ndim:
111-
raise ValueError(
112-
"`indices` and `a` must have the same number of dimensions"
113-
)
114-
115-
# compute dimensions to iterate over
116-
dest_dims = list(range(axis)) + [None] + list(range(axis + 1, ind.ndim))
117-
shape_ones = (1,) * ind.ndim
118-
119-
# build the index
120-
fancy_index = []
121-
for dim, n in zip(dest_dims, a_shape):
122-
if dim is None:
123-
fancy_index.append(ind)
124-
else:
125-
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
126-
fancy_index.append(
127-
dpnp.arange(
128-
n,
129-
dtype=ind.dtype,
130-
usm_type=ind.usm_type,
131-
sycl_queue=ind.sycl_queue,
132-
).reshape(ind_shape)
133-
)
134-
135-
return tuple(fancy_index)
136-
137-
13887
def _ravel_multi_index_checks(multi_index, dims, order):
13988
dpnp.check_supported_arrays_type(*multi_index)
14089
ndim = len(dims)
@@ -1371,7 +1320,7 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):
13711320
in_usm_a[:] = dpt.reshape(usm_a, in_usm_a.shape, copy=False)
13721321

13731322

1374-
def put_along_axis(a, ind, values, axis):
1323+
def put_along_axis(a, ind, values, axis, mode="wrap"):
13751324
"""
13761325
Put values into the destination array by matching 1d index and data slices.
13771326
@@ -1395,9 +1344,18 @@ def put_along_axis(a, ind, values, axis):
13951344
values : {scalar, array_like}, (Ni..., J, Nk...)
13961345
Values to insert at those indices. Its shape and dimension are
13971346
broadcast to match that of `ind`.
1398-
axis : int
1347+
axis : {None, int}
13991348
The axis to take 1d slices along. If axis is ``None``, the destination
14001349
array is treated as if a flattened 1d view had been created of it.
1350+
mode : {"wrap", "clip"}, optional
1351+
Specifies how out-of-bounds indices will be handled. Possible values
1352+
are:
1353+
1354+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
1355+
negative indices.
1356+
- ``"clip"``: clips indices to (``0 <= i < n``).
1357+
1358+
Default: ``"wrap"``.
14011359
14021360
See Also
14031361
--------
@@ -1426,12 +1384,26 @@ def put_along_axis(a, ind, values, axis):
14261384
14271385
"""
14281386

1429-
dpnp.check_supported_arrays_type(a, ind)
1430-
14311387
if axis is None:
1432-
a = a.ravel()
1388+
dpnp.check_supported_arrays_type(ind)
1389+
if ind.ndim != 1:
1390+
raise ValueError(
1391+
"when axis=None, `ind` must have a single dimension."
1392+
)
1393+
1394+
a = dpnp.ravel(a)
1395+
axis = 0
1396+
1397+
usm_a = dpnp.get_usm_ndarray(a)
1398+
usm_ind = dpnp.get_usm_ndarray(ind)
1399+
if dpnp.is_supported_array_type(values):
1400+
usm_vals = dpnp.get_usm_ndarray(values)
1401+
else:
1402+
usm_vals = dpt.asarray(
1403+
values, usm_type=a.usm_type, sycl_queue=a.sycl_queue
1404+
)
14331405

1434-
a[_build_along_axis_index(a, ind, axis)] = values
1406+
dpt.put_along_axis(usm_a, usm_ind, usm_vals, axis=axis, mode=mode)
14351407

14361408

14371409
def putmask(x1, mask, values):

tests/test_indexing.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -594,38 +594,57 @@ def test_replace_max(self, arr_dt, axis):
594594
],
595595
)
596596
def test_values(self, arr_dt, idx_dt, ndim, values):
597-
np_a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
598-
np_ai = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
597+
a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
598+
ind = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
599599
(1,) * (ndim - 1) + (4,)
600600
)
601-
602-
dp_a = dpnp.array(np_a, dtype=arr_dt)
603-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
601+
ia, iind = dpnp.array(a), dpnp.array(ind)
604602

605603
for axis in range(ndim):
606-
numpy.put_along_axis(np_a, np_ai, values, axis)
607-
dpnp.put_along_axis(dp_a, dp_ai, values, axis)
608-
assert_array_equal(np_a, dp_a)
604+
numpy.put_along_axis(a, ind, values, axis)
605+
dpnp.put_along_axis(ia, iind, values, axis)
606+
assert_array_equal(ia, a)
609607

610608
@pytest.mark.parametrize("xp", [numpy, dpnp])
611609
@pytest.mark.parametrize("dt", [bool, numpy.float32])
612610
def test_invalid_indices_dtype(self, xp, dt):
613611
a = xp.ones((10, 10))
614-
ind = xp.ones(10, dtype=dt)
612+
ind = xp.ones_like(a, dtype=dt)
615613
assert_raises(IndexError, xp.put_along_axis, a, ind, 7, axis=1)
616614

617615
@pytest.mark.parametrize("arr_dt", get_all_dtypes())
618616
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
619617
def test_broadcast(self, arr_dt, idx_dt):
620-
np_a = numpy.ones((3, 4, 1), dtype=arr_dt)
621-
np_ai = numpy.arange(10, dtype=idx_dt).reshape((1, 2, 5)) % 4
618+
a = numpy.ones((3, 4, 1), dtype=arr_dt)
619+
ind = numpy.arange(10, dtype=idx_dt).reshape((1, 2, 5)) % 4
620+
ia, iind = dpnp.array(a), dpnp.array(ind)
621+
622+
numpy.put_along_axis(a, ind, 20, axis=1)
623+
dpnp.put_along_axis(ia, iind, 20, axis=1)
624+
assert_array_equal(ia, a)
625+
626+
def test_mode_wrap(self):
627+
a = numpy.array([-2, -1, 0, 1, 2])
628+
ind = numpy.array([-2, 2, -5, 4])
629+
ia, iind = dpnp.array(a), dpnp.array(ind)
630+
631+
dpnp.put_along_axis(ia, iind, 3, axis=0, mode="wrap")
632+
numpy.put_along_axis(a, ind, 3, axis=0)
633+
assert_array_equal(ia, a)
634+
635+
def test_mode_clip(self):
636+
a = dpnp.array([-2, -1, 0, 1, 2])
637+
ind = dpnp.array([-2, 2, -5, 4])
622638

623-
dp_a = dpnp.array(np_a, dtype=arr_dt)
624-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
639+
# numpy does not support keyword `mode`
640+
dpnp.put_along_axis(a, ind, 4, axis=0, mode="clip")
641+
assert (a == dpnp.array([4, -1, 4, 1, 4])).all()
625642

626-
numpy.put_along_axis(np_a, np_ai, 20, axis=1)
627-
dpnp.put_along_axis(dp_a, dp_ai, 20, axis=1)
628-
assert_array_equal(np_a, dp_a)
643+
@pytest.mark.parametrize("xp", [numpy, dpnp])
644+
def test_indices_ndim_axis_none(self, xp):
645+
a = xp.ones((10, 10))
646+
ind = xp.ones((10, 2), dtype=xp.intp)
647+
assert_raises(ValueError, xp.put_along_axis, a, ind, -1, axis=None)
629648

630649

631650
class TestTake:

0 commit comments

Comments
 (0)