Skip to content

Commit 8846d3b

Browse files
committed
Implement dpnp.put_along_axis through dpctl.tensor
1 parent 6ba840a commit 8846d3b

File tree

2 files changed

+56
-73
lines changed

2 files changed

+56
-73
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 27 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -84,57 +84,6 @@
8484
]
8585

8686

87-
def _build_along_axis_index(a, ind, axis):
88-
"""
89-
Build a fancy index used by a family of `_along_axis` functions.
90-
91-
The fancy index consists of orthogonal arranges, with the
92-
requested index inserted at the right location.
93-
94-
The resulting index is going to be used inside `dpnp.put_along_axis`
95-
and `dpnp.take_along_axis` implementations.
96-
97-
"""
98-
99-
if not dpnp.issubdtype(ind.dtype, dpnp.integer):
100-
raise IndexError("`indices` must be an integer array")
101-
102-
# normalize array shape and input axis
103-
if axis is None:
104-
a_shape = (a.size,)
105-
axis = 0
106-
else:
107-
a_shape = a.shape
108-
axis = normalize_axis_index(axis, a.ndim)
109-
110-
if len(a_shape) != ind.ndim:
111-
raise ValueError(
112-
"`indices` and `a` must have the same number of dimensions"
113-
)
114-
115-
# compute dimensions to iterate over
116-
dest_dims = list(range(axis)) + [None] + list(range(axis + 1, ind.ndim))
117-
shape_ones = (1,) * ind.ndim
118-
119-
# build the index
120-
fancy_index = []
121-
for dim, n in zip(dest_dims, a_shape):
122-
if dim is None:
123-
fancy_index.append(ind)
124-
else:
125-
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :]
126-
fancy_index.append(
127-
dpnp.arange(
128-
n,
129-
dtype=ind.dtype,
130-
usm_type=ind.usm_type,
131-
sycl_queue=ind.sycl_queue,
132-
).reshape(ind_shape)
133-
)
134-
135-
return tuple(fancy_index)
136-
137-
13887
def _ravel_multi_index_checks(multi_index, dims, order):
13988
dpnp.check_supported_arrays_type(*multi_index)
14089
ndim = len(dims)
@@ -1371,7 +1320,7 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):
13711320
in_usm_a[:] = dpt.reshape(usm_a, in_usm_a.shape, copy=False)
13721321

13731322

1374-
def put_along_axis(a, ind, values, axis):
1323+
def put_along_axis(a, ind, values, axis, mode="wrap"):
13751324
"""
13761325
Put values into the destination array by matching 1d index and data slices.
13771326
@@ -1395,9 +1344,16 @@ def put_along_axis(a, ind, values, axis):
13951344
values : {scalar, array_like}, (Ni..., J, Nk...)
13961345
Values to insert at those indices. Its shape and dimension are
13971346
broadcast to match that of `ind`.
1398-
axis : int
1347+
axis : {None, int}
13991348
The axis to take 1d slices along. If axis is ``None``, the destination
14001349
array is treated as if a flattened 1d view had been created of it.
1350+
mode : {"wrap", "clip"}, optional
1351+
Specifies how out-of-bounds indices will be handled. Possible values
1352+
are:
1353+
- ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
1354+
negative indices.
1355+
- ``"clip"``: clips indices to (``0 <= i < n``).
1356+
Default: ``"wrap"``.
14011357
14021358
See Also
14031359
--------
@@ -1426,12 +1382,26 @@ def put_along_axis(a, ind, values, axis):
14261382
14271383
"""
14281384

1429-
dpnp.check_supported_arrays_type(a, ind)
1430-
14311385
if axis is None:
1432-
a = a.ravel()
1386+
dpnp.check_supported_arrays_type(ind)
1387+
if ind.ndim != 1:
1388+
raise ValueError(
1389+
"when axis=None, `ind` must have a single dimension."
1390+
)
1391+
1392+
a = dpnp.ravel(a)
1393+
axis = 0
1394+
1395+
usm_a = dpnp.get_usm_ndarray(a)
1396+
usm_ind = dpnp.get_usm_ndarray(ind)
1397+
if dpnp.is_supported_array_type(values):
1398+
usm_vals = dpnp.get_usm_ndarray(values)
1399+
else:
1400+
usm_vals = dpt.asarray(
1401+
values, usm_type=a.usm_type, sycl_queue=a.sycl_queue
1402+
)
14331403

1434-
a[_build_along_axis_index(a, ind, axis)] = values
1404+
dpt.put_along_axis(usm_a, usm_ind, usm_vals, axis=axis, mode=mode)
14351405

14361406

14371407
def putmask(x1, mask, values):

tests/test_indexing.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -594,38 +594,51 @@ def test_replace_max(self, arr_dt, axis):
594594
],
595595
)
596596
def test_values(self, arr_dt, idx_dt, ndim, values):
597-
np_a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
598-
np_ai = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
597+
a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim)
598+
ind = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape(
599599
(1,) * (ndim - 1) + (4,)
600600
)
601-
602-
dp_a = dpnp.array(np_a, dtype=arr_dt)
603-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
601+
ia, iind = dpnp.array(a), dpnp.array(ind)
604602

605603
for axis in range(ndim):
606-
numpy.put_along_axis(np_a, np_ai, values, axis)
607-
dpnp.put_along_axis(dp_a, dp_ai, values, axis)
608-
assert_array_equal(np_a, dp_a)
604+
numpy.put_along_axis(a, ind, values, axis)
605+
dpnp.put_along_axis(ia, iind, values, axis)
606+
assert_array_equal(ia, a)
609607

610608
@pytest.mark.parametrize("xp", [numpy, dpnp])
611609
@pytest.mark.parametrize("dt", [bool, numpy.float32])
612610
def test_invalid_indices_dtype(self, xp, dt):
613611
a = xp.ones((10, 10))
614-
ind = xp.ones(10, dtype=dt)
612+
ind = xp.ones_like(a, dtype=dt)
615613
assert_raises(IndexError, xp.put_along_axis, a, ind, 7, axis=1)
616614

617615
@pytest.mark.parametrize("arr_dt", get_all_dtypes())
618616
@pytest.mark.parametrize("idx_dt", get_integer_dtypes())
619617
def test_broadcast(self, arr_dt, idx_dt):
620-
np_a = numpy.ones((3, 4, 1), dtype=arr_dt)
621-
np_ai = numpy.arange(10, dtype=idx_dt).reshape((1, 2, 5)) % 4
618+
a = numpy.ones((3, 4, 1), dtype=arr_dt)
619+
ind = numpy.arange(10, dtype=idx_dt).reshape((1, 2, 5)) % 4
620+
ia, iind = dpnp.array(a), dpnp.array(ind)
622621

623-
dp_a = dpnp.array(np_a, dtype=arr_dt)
624-
dp_ai = dpnp.array(np_ai, dtype=idx_dt)
622+
numpy.put_along_axis(a, ind, 20, axis=1)
623+
dpnp.put_along_axis(ia, iind, 20, axis=1)
624+
assert_array_equal(ia, a)
625625

626-
numpy.put_along_axis(np_a, np_ai, 20, axis=1)
627-
dpnp.put_along_axis(dp_a, dp_ai, 20, axis=1)
628-
assert_array_equal(np_a, dp_a)
626+
def test_mode_wrap(self):
627+
a = numpy.array([-2, -1, 0, 1, 2])
628+
ind = numpy.array([-2, 2, -5, 4])
629+
ia, iind = dpnp.array(a), dpnp.array(ind)
630+
631+
dpnp.put_along_axis(ia, iind, 3, axis=0, mode="wrap")
632+
numpy.put_along_axis(a, ind, 3, axis=0)
633+
assert_array_equal(ia, a)
634+
635+
def test_mode_clip(self):
636+
a = dpnp.array([-2, -1, 0, 1, 2])
637+
ind = dpnp.array([-2, 2, -5, 4])
638+
639+
# numpy does not support keyword `mode`
640+
dpnp.put_along_axis(a, ind, 4, axis=0, mode="clip")
641+
assert (a == dpnp.array([4, -1, 4, 1, 4])).all()
629642

630643

631644
class TestTake:

0 commit comments

Comments
 (0)