diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index d616ebfcc9cb..f1d00842e5aa 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -84,57 +84,6 @@ ] -def _build_along_axis_index(a, ind, axis): - """ - Build a fancy index used by a family of `_along_axis` functions. - - The fancy index consists of orthogonal arranges, with the - requested index inserted at the right location. - - The resulting index is going to be used inside `dpnp.put_along_axis` - and `dpnp.take_along_axis` implementations. - - """ - - if not dpnp.issubdtype(ind.dtype, dpnp.integer): - raise IndexError("`indices` must be an integer array") - - # normalize array shape and input axis - if axis is None: - a_shape = (a.size,) - axis = 0 - else: - a_shape = a.shape - axis = normalize_axis_index(axis, a.ndim) - - if len(a_shape) != ind.ndim: - raise ValueError( - "`indices` and `a` must have the same number of dimensions" - ) - - # compute dimensions to iterate over - dest_dims = list(range(axis)) + [None] + list(range(axis + 1, ind.ndim)) - shape_ones = (1,) * ind.ndim - - # build the index - fancy_index = [] - for dim, n in zip(dest_dims, a_shape): - if dim is None: - fancy_index.append(ind) - else: - ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim + 1 :] - fancy_index.append( - dpnp.arange( - n, - dtype=ind.dtype, - usm_type=ind.usm_type, - sycl_queue=ind.sycl_queue, - ).reshape(ind_shape) - ) - - return tuple(fancy_index) - - def _ravel_multi_index_checks(multi_index, dims, order): dpnp.check_supported_arrays_type(*multi_index) ndim = len(dims) @@ -1371,7 +1320,7 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"): in_usm_a[:] = dpt.reshape(usm_a, in_usm_a.shape, copy=False) -def put_along_axis(a, ind, values, axis): +def put_along_axis(a, ind, values, axis, mode="wrap"): """ Put values into the destination array by matching 1d index and data slices. @@ -1395,9 +1344,18 @@ def put_along_axis(a, ind, values, axis): values : {scalar, array_like}, (Ni..., J, Nk...) Values to insert at those indices. Its shape and dimension are broadcast to match that of `ind`. - axis : int + axis : {None, int} The axis to take 1d slices along. If axis is ``None``, the destination array is treated as if a flattened 1d view had been created of it. + mode : {"wrap", "clip"}, optional + Specifies how out-of-bounds indices will be handled. Possible values + are: + + - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps + negative indices. + - ``"clip"``: clips indices to (``0 <= i < n``). + + Default: ``"wrap"``. See Also -------- @@ -1426,12 +1384,26 @@ def put_along_axis(a, ind, values, axis): """ - dpnp.check_supported_arrays_type(a, ind) - if axis is None: - a = a.ravel() + dpnp.check_supported_arrays_type(ind) + if ind.ndim != 1: + raise ValueError( + "when axis=None, `ind` must have a single dimension." + ) + + a = dpnp.ravel(a) + axis = 0 + + usm_a = dpnp.get_usm_ndarray(a) + usm_ind = dpnp.get_usm_ndarray(ind) + if dpnp.is_supported_array_type(values): + usm_vals = dpnp.get_usm_ndarray(values) + else: + usm_vals = dpt.asarray( + values, usm_type=a.usm_type, sycl_queue=a.sycl_queue + ) - a[_build_along_axis_index(a, ind, axis)] = values + dpt.put_along_axis(usm_a, usm_ind, usm_vals, axis=axis, mode=mode) def putmask(x1, mask, values): diff --git a/tests/test_indexing.py b/tests/test_indexing.py index ecefc34773ec..0d6e5b4e390d 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -594,38 +594,57 @@ def test_replace_max(self, arr_dt, axis): ], ) def test_values(self, arr_dt, idx_dt, ndim, values): - np_a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim) - np_ai = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape( + a = numpy.arange(4**ndim, dtype=arr_dt).reshape((4,) * ndim) + ind = numpy.array([3, 0, 2, 1], dtype=idx_dt).reshape( (1,) * (ndim - 1) + (4,) ) - - dp_a = dpnp.array(np_a, dtype=arr_dt) - dp_ai = dpnp.array(np_ai, dtype=idx_dt) + ia, iind = dpnp.array(a), dpnp.array(ind) for axis in range(ndim): - numpy.put_along_axis(np_a, np_ai, values, axis) - dpnp.put_along_axis(dp_a, dp_ai, values, axis) - assert_array_equal(np_a, dp_a) + numpy.put_along_axis(a, ind, values, axis) + dpnp.put_along_axis(ia, iind, values, axis) + assert_array_equal(ia, a) @pytest.mark.parametrize("xp", [numpy, dpnp]) @pytest.mark.parametrize("dt", [bool, numpy.float32]) def test_invalid_indices_dtype(self, xp, dt): a = xp.ones((10, 10)) - ind = xp.ones(10, dtype=dt) + ind = xp.ones_like(a, dtype=dt) assert_raises(IndexError, xp.put_along_axis, a, ind, 7, axis=1) @pytest.mark.parametrize("arr_dt", get_all_dtypes()) @pytest.mark.parametrize("idx_dt", get_integer_dtypes()) def test_broadcast(self, arr_dt, idx_dt): - np_a = numpy.ones((3, 4, 1), dtype=arr_dt) - np_ai = numpy.arange(10, dtype=idx_dt).reshape((1, 2, 5)) % 4 + a = numpy.ones((3, 4, 1), dtype=arr_dt) + ind = numpy.arange(10, dtype=idx_dt).reshape((1, 2, 5)) % 4 + ia, iind = dpnp.array(a), dpnp.array(ind) + + numpy.put_along_axis(a, ind, 20, axis=1) + dpnp.put_along_axis(ia, iind, 20, axis=1) + assert_array_equal(ia, a) + + def test_mode_wrap(self): + a = numpy.array([-2, -1, 0, 1, 2]) + ind = numpy.array([-2, 2, -5, 4]) + ia, iind = dpnp.array(a), dpnp.array(ind) + + dpnp.put_along_axis(ia, iind, 3, axis=0, mode="wrap") + numpy.put_along_axis(a, ind, 3, axis=0) + assert_array_equal(ia, a) + + def test_mode_clip(self): + a = dpnp.array([-2, -1, 0, 1, 2]) + ind = dpnp.array([-2, 2, -5, 4]) - dp_a = dpnp.array(np_a, dtype=arr_dt) - dp_ai = dpnp.array(np_ai, dtype=idx_dt) + # numpy does not support keyword `mode` + dpnp.put_along_axis(a, ind, 4, axis=0, mode="clip") + assert (a == dpnp.array([4, -1, 4, 1, 4])).all() - numpy.put_along_axis(np_a, np_ai, 20, axis=1) - dpnp.put_along_axis(dp_a, dp_ai, 20, axis=1) - assert_array_equal(np_a, dp_a) + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_indices_ndim_axis_none(self, xp): + a = xp.ones((10, 10)) + ind = xp.ones((10, 2), dtype=xp.intp) + assert_raises(ValueError, xp.put_along_axis, a, ind, -1, axis=None) class TestTake: