From 09438c3aee35ed44da17f7120bb404c42b1362bb Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 1 Aug 2024 15:25:13 -0500 Subject: [PATCH 01/10] Implement take_along_axis function per Python Array API The function is planned for Python Array API 2024.12 specification. --- dpctl/tensor/__init__.py | 10 +++- dpctl/tensor/_indexing_functions.py | 76 ++++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index dff75b9c2c..579b56d3a3 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -60,7 +60,14 @@ ) from dpctl.tensor._device import Device from dpctl.tensor._dlpack import from_dlpack -from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take +from dpctl.tensor._indexing_functions import ( + extract, + nonzero, + place, + put, + take, + take_along_axis, +) from dpctl.tensor._linear_algebra_functions import ( matmul, matrix_transpose, @@ -376,4 +383,5 @@ "nextafter", "diff", "count_nonzero", + "take_along_axis", ] diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 54f75b45a4..5718f2b4dd 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -21,7 +21,7 @@ import dpctl.tensor._tensor_impl as ti import dpctl.utils -from ._copy_utils import _extract_impl, _nonzero_impl +from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index from ._numpy_helper import normalize_axis_index @@ -423,3 +423,77 @@ def nonzero(arr): if arr.ndim == 0: raise ValueError("Array of positive rank is expected") return _nonzero_impl(arr) + + +def _range(sh_i, i, nd, q, usm_t, dt): + ind = dpt.arange(sh_i, dtype=dt, usm_type=usm_t, sycl_queue=q) + ind.shape = tuple(sh_i if i == j else 1 for j in range(nd)) + return ind + + +def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"): + """ + Returns elements from an array at the one-dimensional indices specified + by 
``indices`` along a provided ``axis``. + + Args: + x (usm_ndarray): + input array. Must be compatible with ``indices``, except for the + axis (dimension) specified by ``axis``. + indices (usm_ndarray): + array indices. Must have the same rank (i.e., number of dimensions) + as ``x``. + axis (int, optional): + axis along which to select values. If ``axis`` is negative, the + function determines the axis along which to select values by + counting from the last dimension. Default: ``-1``. + mode (str, optional): + How out-of-bounds indices will be handled. Possible values + are: + + - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps + negative indices. + - ``"clip"``: clips indices to (``0 <= i < n``). + + Default: ``"wrap"``. + + Returns: + usm_ndarray: + an array having the same data type as ``x``. The returned array has + the same rank (i.e., number of dimensions) as ``x`` and a shape + determined according to :ref:`broadcasting`, except for the axis + (dimension) specified by ``axis`` whose size must equal the size + of the corresponding axis (dimension) in ``indices``. + + Note: + Treatment of the out-of-bound indices in ``indices`` array is controlled + by the value of ``mode`` keyword. + """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError + if not isinstance(indices, dpt.usm_ndarray): + raise TypeError + x_nd = x.ndim + if x_nd != indices.ndim: + raise ValueError + pp = normalize_axis_index(operator.index(axis), x_nd) + out_usm_type = dpctl.utils.get_coerced_usm_type( + (x.usm_type, indices.usm_type) + ) + exec_q = dpctl.utils.get_execution_queue((x.sycl_queue, indices.sycl_queue)) + if exec_q is None: + raise dpctl.utils.ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. 
" + ) + mode_i = _get_indexing_mode(mode) + indexes_dt = ti.default_device_index_type(exec_q.sycl_device) + _ind = tuple( + ( + indices + if i == pp + else _range(x.shape[i], i, x_nd, exec_q, out_usm_type, indexes_dt) + ) + for i in range(x_nd) + ) + return _take_multi_index(x, _ind, 0, mode=mode_i) From e73a538cf7f4b412c9cfdadb5a1376ea5207c9be Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 1 Aug 2024 16:08:28 -0500 Subject: [PATCH 02/10] Add take_along_axis to index of dpctl.tensor.indexing_functions --- .../api_reference/dpctl/tensor.indexing_functions.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst b/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst index f4a35d40db..09287ba49f 100644 --- a/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst +++ b/docs/doc_sources/api_reference/dpctl/tensor.indexing_functions.rst @@ -15,3 +15,4 @@ by either integral arrays of indices or boolean mask arrays. 
place put take + take_along_axis From 92ca43861eac6f0b49f21c7fdb936f5f6093221d Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 1 Aug 2024 17:10:51 -0500 Subject: [PATCH 03/10] Add mode keyword to _take_multi_index with default 0 --- dpctl/tensor/_copy_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index cbe86ad06b..dc5e7268a4 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -795,13 +795,18 @@ def _nonzero_impl(ary): return res -def _take_multi_index(ary, inds, p): +def _take_multi_index(ary, inds, p, mode=0): if not isinstance(ary, dpt.usm_ndarray): raise TypeError( f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}" ) ary_nd = ary.ndim p = normalize_axis_index(operator.index(p), ary_nd) + mode = operator.index(mode) + if mode not in [0, 1]: + raise ValueError( + "Invalid value for mode keyword, only 0 or 1 is supported" + ) queues_ = [ ary.sycl_queue, ] @@ -860,7 +865,7 @@ def _take_multi_index(ary, inds, p): ind=inds, dst=res, axis_start=p, - mode=0, + mode=mode, sycl_queue=exec_q, depends=dep_ev, ) From c8479dd08247211bc1b71dd481ed9bc7659b4cb2 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 1 Aug 2024 17:14:25 -0500 Subject: [PATCH 04/10] Basic test for take_along_axis added --- dpctl/tests/test_usm_ndarray_indexing.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 5fb7da4543..e5ffdc36a5 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1537,3 +1537,23 @@ def test_advanced_integer_indexing_cast_indices(): x = dpt.ones((3, 4, 5, 6), dtype="i4") with pytest.raises(ValueError): x[inds0, inds1, inds2, ...] 
+ + +def test_take_along_axis(): + get_queue_or_skip() + + n0, n1, n2 = 3, 5, 7 + x = dpt.reshape(dpt.arange(n0 * n1 * n2), (n0, n1, n2)) + ind_dt = dpt.__array_namespace_info__().default_dtypes( + device=x.sycl_device + )["indexing"] + ind0 = dpt.ones((1, n1, n2), dtype=ind_dt) + ind1 = dpt.ones((n0, 1, n2), dtype=ind_dt) + ind2 = dpt.ones((n0, n1, 1), dtype=ind_dt) + + y0 = dpt.take_along_axis(x, ind0, axis=0) + assert y0.shape == ind0.shape + y1 = dpt.take_along_axis(x, ind1, axis=1) + assert y1.shape == ind1.shape + y2 = dpt.take_along_axis(x, ind2, axis=2) + assert y2.shape == ind2.shape From c139d2c60367bab7aac1b4147c0e850454dcea76 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Thu, 1 Aug 2024 17:36:31 -0500 Subject: [PATCH 05/10] Add take_along_axis arg validation test --- dpctl/tests/test_usm_ndarray_indexing.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index e5ffdc36a5..8d487f74ac 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1557,3 +1557,20 @@ def test_take_along_axis(): assert y1.shape == ind1.shape y2 = dpt.take_along_axis(x, ind2, axis=2) assert y2.shape == ind2.shape + + +def test_take_along_axis_validation(): + with pytest.raises(TypeError): + dpt.take_along_axis(tuple(), list()) + get_queue_or_skip() + x = dpt.ones(10) + with pytest.raises(TypeError): + dpt.take_along_axis(x, list()) + ind_dt = dpt.__array_namespace_info__().default_dtypes( + device=x.sycl_device + )["indexing"] + ind = dpt.zeros(1, dtype=ind_dt) + with pytest.raises(ValueError): + dpt.take_along_axis(x, ind, axis=1) + with pytest.raises(ValueError): + dpt.take_along_axis(x, ind, axis=0, mode="invalid") From bafe1251294dde41dfdb964c2421c77dffe5f662 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 2 Aug 2024 13:47:18 -0500 Subject: [PATCH 06/10] Fill exception messages --- 
dpctl/tensor/_indexing_functions.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/dpctl/tensor/_indexing_functions.py b/dpctl/tensor/_indexing_functions.py index 5718f2b4dd..a0ac2bb6cb 100644 --- a/dpctl/tensor/_indexing_functions.py +++ b/dpctl/tensor/_indexing_functions.py @@ -470,12 +470,17 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"): by the value of ``mode`` keyword. """ if not isinstance(x, dpt.usm_ndarray): - raise TypeError + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") if not isinstance(indices, dpt.usm_ndarray): - raise TypeError + raise TypeError( + f"Expected dpctl.tensor.usm_ndarray, got {type(indices)}" + ) x_nd = x.ndim if x_nd != indices.ndim: - raise ValueError + raise ValueError( + "Number of dimensions in the first and the second " + "argument arrays must be equal" + ) pp = normalize_axis_index(operator.index(axis), x_nd) out_usm_type = dpctl.utils.get_coerced_usm_type( (x.usm_type, indices.usm_type) From b6e4ea356e8d43ed16735a9ed97e3adcee3e1113 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 2 Aug 2024 13:53:20 -0500 Subject: [PATCH 07/10] Expand tests to improve coverage based on coverage results --- dpctl/tests/test_usm_ndarray_indexing.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 8d487f74ac..03557fa77c 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1560,17 +1560,31 @@ def test_take_along_axis(): def test_take_along_axis_validation(): + # type check on the first argument with pytest.raises(TypeError): dpt.take_along_axis(tuple(), list()) get_queue_or_skip() - x = dpt.ones(10) + n1, n2 = 2, 5 + x = dpt.ones(n1 * n2) + # type check on the second argument with pytest.raises(TypeError): dpt.take_along_axis(x, list()) - ind_dt = 
dpt.__array_namespace_info__().default_dtypes( - device=x.sycl_device - )["indexing"] + x_dev = x.sycl_device + info_ = dpt.__array_namespace_info__() + def_dtypes = info_.default_dtypes(device=x_dev) + ind_dt = def_dtypes["indexing"] ind = dpt.zeros(1, dtype=ind_dt) + # axis validation with pytest.raises(ValueError): dpt.take_along_axis(x, ind, axis=1) + # mode validation with pytest.raises(ValueError): dpt.take_along_axis(x, ind, axis=0, mode="invalid") + # same array-ranks validation + with pytest.raises(ValueError): + dpt.take_along_axis(dpt.reshape(x, (n1, n2)), ind) + # check compute-follows-data + q2 = dpctl.SyclQueue(x_dev, property="enable_profiling") + ind2 = dpt.zeros(1, dtype=ind_dt, sycl_queue=q2) + with pytest.raises(ExecutionPlacementError): + dpt.take_along_axis(x, ind2) From a9a261e33c92cc195055a6ca2d4be1afb0fd3499 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 2 Aug 2024 14:07:52 -0500 Subject: [PATCH 08/10] Exercise setitems for some checks made for getitem --- dpctl/tests/test_usm_ndarray_indexing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 03557fa77c..0f24c7638e 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1535,8 +1535,12 @@ def test_advanced_integer_indexing_cast_indices(): inds1 = dpt.astype(inds0, "u4") inds2 = dpt.astype(inds0, "u8") x = dpt.ones((3, 4, 5, 6), dtype="i4") + # test getitem with pytest.raises(ValueError): x[inds0, inds1, inds2, ...] + # test setitem + with pytest.raises(ValueError): + x[inds0, inds1, inds2, ...] 
= 1 def test_take_along_axis(): From 0e69998a219e9e22d8edad9f0368881897224137 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Fri, 2 Aug 2024 14:55:43 -0500 Subject: [PATCH 09/10] Add tests of internal functions to improve coverage of _copy_utils --- dpctl/tests/test_usm_ndarray_indexing.py | 91 ++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_indexing.py b/dpctl/tests/test_usm_ndarray_indexing.py index 0f24c7638e..9fb2f04946 100644 --- a/dpctl/tests/test_usm_ndarray_indexing.py +++ b/dpctl/tests/test_usm_ndarray_indexing.py @@ -1592,3 +1592,94 @@ def test_take_along_axis_validation(): ind2 = dpt.zeros(1, dtype=ind_dt, sycl_queue=q2) with pytest.raises(ExecutionPlacementError): dpt.take_along_axis(x, ind2) + + +def check__extract_impl_validation(fn): + x = dpt.ones(10) + ind = dpt.ones(10, dtype="?") + with pytest.raises(TypeError): + fn(list(), ind) + with pytest.raises(TypeError): + fn(x, list()) + q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling") + ind2 = dpt.ones(10, dtype="?", sycl_queue=q2) + with pytest.raises(ExecutionPlacementError): + fn(x, ind2) + with pytest.raises(ValueError): + fn(x, ind, 1) + + +def check__nonzero_impl_validation(fn): + with pytest.raises(TypeError): + fn(list()) + + +def check__take_multi_index(fn): + x = dpt.ones(10) + x_dev = x.sycl_device + info_ = dpt.__array_namespace_info__() + def_dtypes = info_.default_dtypes(device=x_dev) + ind_dt = def_dtypes["indexing"] + ind = dpt.arange(10, dtype=ind_dt) + with pytest.raises(TypeError): + fn(list(), tuple(), 1) + with pytest.raises(ValueError): + fn(x, (ind,), 0, mode=2) + with pytest.raises(ValueError): + fn(x, (None,), 1) + with pytest.raises(IndexError): + fn(x, (x,), 1) + q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling") + ind2 = dpt.arange(10, dtype=ind_dt, sycl_queue=q2) + with pytest.raises(ExecutionPlacementError): + fn(x, (ind2,), 0) + m = dpt.ones((10, 10)) + ind_1 = dpt.arange(10, 
dtype="i8") + ind_2 = dpt.arange(10, dtype="u8") + with pytest.raises(ValueError): + fn(m, (ind_1, ind_2), 0) + + +def check__place_impl_validation(fn): + with pytest.raises(TypeError): + fn(list(), list(), list()) + x = dpt.ones(10) + with pytest.raises(TypeError): + fn(x, list(), list()) + q2 = dpctl.SyclQueue(x.sycl_device, property="enable_profiling") + mask2 = dpt.ones(10, dtype="?", sycl_queue=q2) + with pytest.raises(ExecutionPlacementError): + fn(x, mask2, 1) + x2 = dpt.ones((5, 5)) + mask2 = dpt.ones((5, 5), dtype="?") + with pytest.raises(ValueError): + fn(x2, mask2, x2, axis=1) + + +def check__put_multi_index_validation(fn): + with pytest.raises(TypeError): + fn(list(), list(), 0, list()) + x = dpt.ones(10) + inds = dpt.arange(10, dtype="i8") + vals = dpt.zeros(10) + # test inds which is not a tuple/list + fn(x, inds, 0, vals) + x2 = dpt.ones((5, 5)) + ind1 = dpt.arange(5, dtype="i8") + ind2 = dpt.arange(5, dtype="u8") + with pytest.raises(ValueError): + fn(x2, (ind1, ind2), 0, x2) + with pytest.raises(TypeError): + fn(x2, (ind1, list()), 0, x2) + + +def test__copy_utils(): + import dpctl.tensor._copy_utils as cu + + get_queue_or_skip() + + check__extract_impl_validation(cu._extract_impl) + check__nonzero_impl_validation(cu._nonzero_impl) + check__take_multi_index(cu._take_multi_index) + check__place_impl_validation(cu._place_impl) + check__put_multi_index_validation(cu._put_multi_index) From d07de324a3583b899c9d26567b56c722a92f97eb Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Sat, 3 Aug 2024 11:33:03 -0500 Subject: [PATCH 10/10] Changed argsort tests to use take_along_axis --- dpctl/tests/test_usm_ndarray_sorting.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_sorting.py b/dpctl/tests/test_usm_ndarray_sorting.py index fa73dcfdfa..088780d103 100644 --- a/dpctl/tests/test_usm_ndarray_sorting.py +++ b/dpctl/tests/test_usm_ndarray_sorting.py @@ -177,12 +177,24 @@ def 
test_argsort_axis0(): x = dpt.reshape(xf, (n, m)) idx = dpt.argsort(x, axis=0) - conseq_idx = dpt.arange(m, dtype=idx.dtype) - s = x[idx, conseq_idx[dpt.newaxis, :]] + s = dpt.take_along_axis(x, idx, axis=0) assert dpt.all(s[:-1, :] <= s[1:, :]) +def test_argsort_axis1(): + get_queue_or_skip() + + n, m = 200, 30 + xf = dpt.arange(n * m, 0, step=-1, dtype="i4") + x = dpt.reshape(xf, (n, m)) + idx = dpt.argsort(x, axis=1) + + s = dpt.take_along_axis(x, idx, axis=1) + + assert dpt.all(s[:, :-1] <= s[:, 1:]) + + def test_sort_strided(): get_queue_or_skip() @@ -199,8 +211,9 @@ def test_argsort_strided(): x_orig = dpt.arange(100, dtype="i4") x_flipped = dpt.flip(x_orig, axis=0) idx = dpt.argsort(x_flipped) + s = dpt.take_along_axis(x_flipped, idx, axis=0) - assert dpt.all(x_flipped[idx] == x_orig) + assert dpt.all(s == x_orig) def test_sort_0d_array():