diff --git a/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst b/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst index f4a35d40db..09287ba49f 100644 --- a/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst +++ b/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst @@ -15,3 +15,4 @@ by either integral arrays of indices or boolean mask arrays. place put take + take_along_axis diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index dff75b9c2c..579b56d3a3 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -60,7 +60,14 @@ ) from dpctl.tensor._device import Device from dpctl.tensor._dlpack import from_dlpack -from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take +from dpctl.tensor._indexing_functions import ( + extract, + nonzero, + place, + put, + take, + take_along_axis, +) from dpctl.tensor._linear_algebra_functions import ( matmul, matrix_transpose, @@ -376,4 +383,5 @@ "nextafter", "diff", "count_nonzero", + "take_along_axis", ] diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index cbe86ad06b..dc5e7268a4 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -795,13 +795,18 @@ def _nonzero_impl(ary): return res -def _take_multi_index(ary, inds, p): +def _take_multi_index(ary, inds, p, mode=0): if not isinstance(ary, dpt.usm_ndarray): raise TypeError( f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}" ) ary_nd = ary.ndim p = normalize_axis_index(operator.index(p), ary_nd) + mode = operator.index(mode) + if mode not in [0, 1]: + raise ValueError( + "Invalid value for mode keyword, only 0 or 1 is supported" + ) queues_ = [ ary.sycl_queue, ] @@ -860,7 +865,7 @@ def _take_multi_index(ary, inds, p): ind=inds, dst=res, axis_start=p, - mode=0, + mode=mode, sycl_queue=exec_q, depends=dep_ev, ) diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 54f75b45a4..a0ac2bb6cb 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -21,7 +21,7 @@ import dpctl.tensor._tensor_impl as ti import dpctl.utils -from ._copy_utils import _extract_impl, _nonzero_impl +from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index from ._numpy_helper import normalize_axis_index @@ -423,3 +423,82 @@ def nonzero(arr): if arr.ndim == 0: raise ValueError("Array of positive rank is expected") return _nonzero_impl(arr) + + +def _range(sh_i, i, nd, q, usm_t, dt): + ind = dpt.arange(sh_i, dtype=dt, usm_type=usm_t, sycl_queue=q) + ind.shape = tuple(sh_i if i == j else 1 for j in range(nd)) + return ind + + +def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"): + """ + Returns elements from an array at the one-dimensional indices specified + by ``indices`` along a provided ``axis``. + + Args: + x (usm_ndarray): + input array. Must be compatible with ``indices``, except for the + axis (dimension) specified by ``axis``. + indices (usm_ndarray): + array indices. Must have the same rank (i.e., number of dimensions) + as ``x``. + axis: int + axis along which to select values. If ``axis`` is negative, the + function determines the axis along which to select values by + counting from the last dimension. Default: ``-1``. + mode (str, optional): + How out-of-bounds indices will be handled. Possible values + are: + + - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps + negative indices. + - ``"clip"``: clips indices to (``0 <= i < n``). + + Default: ``"wrap"``. + + Returns: + usm_ndarray: + an array having the same data type as ``x``. The returned array has + the same rank (i.e., number of dimensions) as ``x`` and a shape + determined according to :ref:`broadcasting`, except for the axis + (dimension) specified by ``axis`` whose size must equal the size + of the corresponding axis (dimension) in ``indices``. + + Note: + Treatment of the out-of-bound indices in ``indices`` array is controlled + by the value of ``mode`` keyword. + """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + if not isinstance(indices, dpt.usm_ndarray): + raise TypeError( + f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}" + ) + x_nd = x.ndim + if x_nd != indices.ndim: + raise ValueError( + "Number of dimensions in the first and the second " + "argument arrays must be equal" + ) + pp = normalize_axis_index(operator.index(axis), x_nd) + out_usm_type = dpctl.utils.get_coerced_usm_type( + (x.usm_type, indices.usm_type) + ) + exec_q = dpctl.utils.get_execution_queue((x.sycl_queue, indices.sycl_queue)) + if exec_q is None: + raise dpctl.utils.ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. " + ) + mode_i = _get_indexing_mode(mode) + indexes_dt = ti.default_device_index_type(exec_q.sycl_device) + _ind = tuple( + ( + indices + if i == pp + else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt) + ) + for i in range(x_nd) + ) + return _take_multi_index(x, _ind, 0, mode=mode_i) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 5fb7da4543..9fb2f04946 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1535,5 +1535,151 @@ def test_advanced_integer_indexing_cast_indices(): inds1 = dpt.astype(inds0, "u4") inds2 = dpt.astype(inds0, "u8") x = dpt.ones((3, 4, 5, 6), dtype="i4") + # test getitem with pytest.raises(ValueError): x[inds0, inds1, inds2, ...] + # test setitem + with pytest.raises(ValueError): + x[inds0, inds1, inds2, ...] = 1 + + +def test_take_along_axis(): + get_queue_or_skip() + + n0, n1, n2 = 3, 5, 7 + x = dpt.reshape(dpt.arange(n0 * n1 * n2), (n0, n1, n2)) + ind_dt = dpt.__array_namespace_info__().default_dtypes( + device=x.sycl_device + )["indexing"] + ind0 = dpt.ones((1, n1, n2), dtype=ind_dt) + ind1 = dpt.ones((n0, 1, n2), dtype=ind_dt) + ind2 = dpt.ones((n0, n1, 1), dtype=ind_dt) + + y0 = dpt.take_along_axis(x, ind0, axis=0) + assert y0.shape == ind0.shape + y1 = dpt.take_along_axis(x, ind1, axis=1) + assert y1.shape == ind1.shape + y2 = dpt.take_along_axis(x, ind2, axis=2) + assert y2.shape == ind2.shape + + +def test_take_along_axis_validation(): + # type check on the first argument + with pytest.raises(TypeError): + dpt.take_along_axis(tuple(), list()) + get_queue_or_skip() + n1, n2 = 2, 5 + x = dpt.ones(n1 * n2) + # type check on the second argument + with pytest.raises(TypeError): + dpt.take_along_axis(x, list()) + x_dev = x.sycl_device + info_ = dpt.__array_namespace_info__() + def_dtypes = info_.default_dtypes(device=x_dev) + ind_dt = def_dtypes["indexing"] + ind = dpt.zeros(1, dtype=ind_dt) + # axis valudation + with pytest.raises(ValueError): + dpt.take_along_axis(x, ind, axis=1) + # mode validation + with pytest.raises(ValueError): + dpt.take_along_axis(x, ind, axis=0, mode="invalid") + # same array-ranks validation + with pytest.raises(ValueError): + dpt.take_along_axis(dpt.reshape(x, (n1, n2)), ind) + # check compute-follows-data + q2 = dpctl.SyclQueue(x_dev, property="enable_profiling") + ind2 = dpt.zeros(1, dtype=ind_dt, sycl_queue=q2) + with pytest.raises(ExecutionPlacementError): + dpt.take_along_axis(x, ind2) + + +def check__extract_impl_validation(fn): + x = dpt.ones(10) + ind = dpt.ones(10, dtype="?") + with pytest.raises(TypeError): + fn(list(), ind) + with pytest.raises(TypeError): + fn(x, list()) + q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling") + ind2 = dpt.ones(10, dtype="?", sycl_queue=q2) + with pytest.raises(ExecutionPlacementError): + fn(x, ind2) + with pytest.raises(ValueError): + fn(x, ind, 1) + + +def check__nonzero_impl_validation(fn): + with pytest.raises(TypeError): + fn(list()) + + +def check__take_multi_index(fn): + x = dpt.ones(10) + x_dev = x.sycl_device + info_ = dpt.__array_namespace_info__() + def_dtypes = info_.default_dtypes(device=x_dev) + ind_dt = def_dtypes["indexing"] + ind = dpt.arange(10, dtype=ind_dt) + with pytest.raises(TypeError): + fn(list(), tuple(), 1) + with pytest.raises(ValueError): + fn(x, (ind,), 0, mode=2) + with pytest.raises(ValueError): + fn(x, (None,), 1) + with pytest.raises(IndexError): + fn(x, (x,), 1) + q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling") + ind2 = dpt.arange(10, dtype=ind_dt, sycl_queue=q2) + with pytest.raises(ExecutionPlacementError): + fn(x, (ind2,), 0) + m = dpt.ones((10, 10)) + ind_1 = dpt.arange(10, dtype="i8") + ind_2 = dpt.arange(10, dtype="u8") + with pytest.raises(ValueError): + fn(m, (ind_1, ind_2), 0) + + +def check__place_impl_validation(fn): + with pytest.raises(TypeError): + fn(list(), list(), list()) + x = dpt.ones(10) + with pytest.raises(TypeError): + fn(x, list(), list()) + q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling") + mask2 = dpt.ones(10, dtype="?", sycl_queue=q2) + with pytest.raises(ExecutionPlacementError): + fn(x, mask2, 1) + x2 = dpt.ones((5, 5)) + mask2 = dpt.ones((5, 5), dtype="?") + with pytest.raises(ValueError): + fn(x2, mask2, x2, axis=1) + + +def check__put_multi_index_validation(fn): + with pytest.raises(TypeError): + fn(list(), list(), 0, list()) + x = dpt.ones(10) + inds = dpt.arange(10, dtype="i8") + vals = dpt.zeros(10) + # test inds which is not a tuple/list + fn(x, inds, 0, vals) + x2 = dpt.ones((5, 5)) + ind1 = dpt.arange(5, dtype="i8") + ind2 = dpt.arange(5, dtype="u8") + with pytest.raises(ValueError): + fn(x2, (ind1, ind2), 0, x2) + with pytest.raises(TypeError): + fn(x2, (ind1, list()), 0, x2) + + +def test__copy_utils(): + import dpctl.tensor._copy_utils as cu + + get_queue_or_skip() + + check__extract_impl_validation(cu._extract_impl) + check__nonzero_impl_validation(cu._nonzero_impl) + check__take_multi_index(cu._take_multi_index) + check__place_impl_validation(cu._place_impl) + check__put_multi_index_validation(cu._put_multi_index) diff --git a/dpctl/tests/test_usm_ndarray_sorting.py b/dpctl/tests/test_usm_ndarray_sorting.py index fa73dcfdfa..088780d103 100644 --- a/dpctl/tests/test_usm_ndarray_sorting.py +++ b/dpctl/tests/test_usm_ndarray_sorting.py @@ -177,12 +177,24 @@ def test_argsort_axis0(): x = dpt.reshape(xf, (n, m)) idx = dpt.argsort(x, axis=0) - conseq_idx = dpt.arange(m, dtype=idx.dtype) - s = x[idx, conseq_idx[dpt.newaxis, :]] + s = dpt.take_along_axis(x, idx, axis=0) assert dpt.all(s[:-1, :] <= s[1:, :]) +def test_argsort_axis1(): + get_queue_or_skip() + + n, m = 200, 30 + xf = dpt.arange(n * m, 0, step=-1, dtype="i4") + x = dpt.reshape(xf, (n, m)) + idx = dpt.argsort(x, axis=1) + + s = dpt.take_along_axis(x, idx, axis=1) + + assert dpt.all(s[:, :-1] <= s[:, 1:]) + + def test_sort_strided(): get_queue_or_skip() @@ -199,8 +211,9 @@ def test_argsort_strided(): x_orig = dpt.arange(100, dtype="i4") x_flipped = dpt.flip(x_orig, axis=0) idx = dpt.argsort(x_flipped) + s = dpt.take_along_axis(x_flipped, idx, axis=0) - assert dpt.all(x_flipped[idx] == x_orig) + assert dpt.all(s == x_orig) def test_sort_0d_array():