Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
nonzero,
place,
put,
put_along_axis,
take,
take_along_axis,
)
Expand Down Expand Up @@ -385,4 +386,5 @@
"count_nonzero",
"DLDeviceType",
"take_along_axis",
"put_along_axis",
]
9 changes: 7 additions & 2 deletions dpctl/tensor/_copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,13 +938,18 @@ def _place_impl(ary, ary_mask, vals, axis=0):
return


def _put_multi_index(ary, inds, p, vals):
def _put_multi_index(ary, inds, p, vals, mode=0):
if not isinstance(ary, dpt.usm_ndarray):
raise TypeError(
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
)
ary_nd = ary.ndim
p = normalize_axis_index(operator.index(p), ary_nd)
mode = operator.index(mode)
if mode not in [0, 1]:
raise ValueError(
"Invalid value for mode keyword, only 0 or 1 is supported"
)
if isinstance(vals, dpt.usm_ndarray):
queues_ = [ary.sycl_queue, vals.sycl_queue]
usm_types_ = [ary.usm_type, vals.usm_type]
Expand Down Expand Up @@ -1018,7 +1023,7 @@ def _put_multi_index(ary, inds, p, vals):
ind=inds,
val=rhs,
axis_start=p,
mode=0,
mode=mode,
sycl_queue=exec_q,
depends=dep_ev,
)
Expand Down
101 changes: 88 additions & 13 deletions dpctl/tensor/_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.utils

from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
from ._copy_utils import (
_extract_impl,
_nonzero_impl,
_put_multi_index,
_take_multi_index,
)
from ._numpy_helper import normalize_axis_index


Expand Down Expand Up @@ -206,22 +211,18 @@ def put_vec_duplicates(vec, ind, vals):
raise TypeError(
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
)
if isinstance(vals, dpt.usm_ndarray):
queues_ = [x.sycl_queue, vals.sycl_queue]
usm_types_ = [x.usm_type, vals.usm_type]
else:
queues_ = [
x.sycl_queue,
]
usm_types_ = [
x.usm_type,
]
if not isinstance(indices, dpt.usm_ndarray):
raise TypeError(
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
type(indices)
)
)
if isinstance(vals, dpt.usm_ndarray):
queues_ = [x.sycl_queue, indices.sycl_queue, vals.sycl_queue]
usm_types_ = [x.usm_type, indices.usm_type, vals.usm_type]
else:
queues_ = [x.sycl_queue, indices.sycl_queue]
usm_types_ = [x.usm_type, indices.usm_type]
if indices.ndim != 1:
raise ValueError(
"`indices` expected a 1D array, got `{}`".format(indices.ndim)
Expand All @@ -232,8 +233,6 @@ def put_vec_duplicates(vec, ind, vals):
indices.dtype
)
)
queues_.append(indices.sycl_queue)
usm_types_.append(indices.usm_type)
exec_q = dpctl.utils.get_execution_queue(queues_)
if exec_q is None:
raise dpctl.utils.ExecutionPlacementError
Expand Down Expand Up @@ -502,3 +501,79 @@ def take_along_axis(x, indices, /, *, axis=-1, mode="wrap"):
for i in range(x_nd)
)
return _take_multi_index(x, _ind, 0, mode=mode_i)


def put_along_axis(x, indices, vals, /, *, axis=-1, mode="wrap"):
    """
    Writes ``vals`` into ``x`` at the one-dimensional positions given by
    ``indices`` along the dimension selected by ``axis``.

    Args:
        x (usm_ndarray):
            array to write into. Must be compatible with ``indices``,
            except for the axis (dimension) specified by ``axis``.
        indices (usm_ndarray):
            array of indices. Must have the same rank (i.e., number of
            dimensions) as ``x``.
        vals:
            values to be put into ``x``. Must be broadcastable to the
            shape of ``indices``. When ``vals`` is a ``usm_ndarray`` its
            queue and USM type take part in execution-placement inference;
            any other object is forwarded as-is to the underlying put
            implementation.
        axis (int, optional):
            axis along which values are put. A negative ``axis`` is
            counted from the last dimension. Default: ``-1``.
        mode (str, optional):
            How out-of-bounds indices will be handled. Possible values
            are:

            - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
              negative indices.
            - ``"clip"``: clips indices to (``0 <= i < n``).

            Default: ``"wrap"``.

    Returns:
        Whatever the internal ``_put_multi_index`` helper returns; the
        visible effect of the call is the in-place modification of ``x``.

    Raises:
        TypeError:
            if ``x`` or ``indices`` is not a ``usm_ndarray``.
        ValueError:
            if ``x`` and ``indices`` differ in number of dimensions.
            Invalid ``axis`` or ``mode`` values are rejected by
            ``normalize_axis_index`` and ``_get_indexing_mode``.
        ExecutionPlacementError:
            if no common execution queue can be inferred from the
            array arguments.

    .. note::

        If input array ``indices`` contains duplicates, a race condition
        occurs, and the value written into corresponding positions in ``x``
        may vary from run to run. Preserving sequential semantics in
        handling the duplicates to achieve deterministic behavior requires
        additional work.
    """
    # Both leading arguments must be device arrays; ``x`` is checked first.
    for arg in (x, indices):
        if not isinstance(arg, dpt.usm_ndarray):
            raise TypeError(
                f"Expected dpctl.tensor.usm_ndarray, got {type(arg)}"
            )
    rank = x.ndim
    if indices.ndim != rank:
        raise ValueError(
            "Number of dimensions in the first and the second "
            "argument arrays must be equal"
        )
    axis_pos = normalize_axis_index(operator.index(axis), rank)

    # Only usm_ndarray arguments carry a queue / USM type to reconcile.
    device_arrays = [x, indices]
    if isinstance(vals, dpt.usm_ndarray):
        device_arrays.append(vals)
    queues = [arr.sycl_queue for arr in device_arrays]
    usm_types = [arr.usm_type for arr in device_arrays]

    exec_q = dpctl.utils.get_execution_queue(queues)
    if exec_q is None:
        raise dpctl.utils.ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments. "
        )
    coerced_usm_type = dpctl.utils.get_coerced_usm_type(usm_types)
    mode_code = _get_indexing_mode(mode)
    index_dtype = ti.default_device_index_type(exec_q.sycl_device)

    # One index array per dimension: the caller's ``indices`` on the chosen
    # axis, a ``_range`` array over the full extent on every other axis.
    multi_index = []
    for dim in range(rank):
        if dim == axis_pos:
            multi_index.append(indices)
        else:
            multi_index.append(
                _range(
                    x.shape[dim],
                    dim,
                    rank,
                    exec_q,
                    coerced_usm_type,
                    index_dtype,
                )
            )
    return _put_multi_index(x, tuple(multi_index), 0, vals, mode=mode_code)
116 changes: 115 additions & 1 deletion dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,7 +1578,7 @@ def test_take_along_axis_validation():
def_dtypes = info_.default_dtypes(device=x_dev)
ind_dt = def_dtypes["indexing"]
ind = dpt.zeros(1, dtype=ind_dt)
# axis valudation
# axis validation
with pytest.raises(ValueError):
dpt.take_along_axis(x, ind, axis=1)
# mode validation
Expand All @@ -1594,6 +1594,116 @@ def test_take_along_axis_validation():
dpt.take_along_axis(x, ind2)


def test_put_along_axis():
    """Round-trip check: values put along each axis are read back by take."""
    get_queue_or_skip()

    n0, n1, n2 = 3, 5, 7
    shape = (n0, n1, n2)
    x = dpt.reshape(dpt.arange(n0 * n1 * n2), shape)
    ind_dt = dpt.__array_namespace_info__().default_dtypes(
        device=x.sycl_device
    )["indexing"]
    # One all-ones index array per axis, with extent 1 along that axis.
    index_arrays = tuple(
        dpt.ones(shape[:ax] + (1,) + shape[ax + 1 :], dtype=ind_dt)
        for ax in range(len(shape))
    )

    for ax, ind in enumerate(index_arrays):
        target = dpt.copy(x)
        vals = dpt.ones(ind.shape, dtype=x.dtype)
        dpt.put_along_axis(target, ind, vals, axis=ax)
        assert dpt.all(dpt.take_along_axis(target, ind, axis=ax) == vals)

    # Same check along the last axis, passing the values as a host array.
    last_ind = index_arrays[2]
    target = dpt.copy(x)
    vals = dpt.ones(last_ind.shape, dtype=x.dtype)
    dpt.put_along_axis(target, last_ind, dpt.asnumpy(vals), axis=2)
    assert dpt.all(dpt.take_along_axis(target, last_ind, axis=2) == vals)


def test_put_along_axis_validation():
    """Each invalid-argument path of put_along_axis raises as expected."""
    # the first argument must be a usm_ndarray (needs no device)
    with pytest.raises(TypeError):
        dpt.put_along_axis((), [], [])
    get_queue_or_skip()
    n_rows, n_cols = 2, 5
    x = dpt.ones(n_rows * n_cols)
    # the second argument must be a usm_ndarray as well
    with pytest.raises(TypeError):
        dpt.put_along_axis(x, [], [])
    device = x.sycl_device
    namespace_info = dpt.__array_namespace_info__()
    ind_dt = namespace_info.default_dtypes(device=device)["indexing"]
    ind = dpt.zeros(1, dtype=ind_dt)
    vals = dpt.zeros(1, dtype=x.dtype)
    # axis=1 is out of range for a one-dimensional array
    with pytest.raises(ValueError):
        dpt.put_along_axis(x, ind, vals, axis=1)
    # unknown mode string
    with pytest.raises(ValueError):
        dpt.put_along_axis(x, ind, vals, axis=0, mode="invalid")
    # x and indices must have the same number of dimensions
    x_2d = dpt.reshape(x, (n_rows, n_cols))
    with pytest.raises(ValueError):
        dpt.put_along_axis(x_2d, ind, vals)
    # compute-follows-data: indices allocated on a different queue
    other_q = dpctl.SyclQueue(device, property="enable_profiling")
    ind_other_q = dpt.zeros(1, dtype=ind_dt, sycl_queue=other_q)
    with pytest.raises(ExecutionPlacementError):
        dpt.put_along_axis(x, ind_other_q, vals)


def test_put_along_axis_application():
    """Build all 4x4 permutation matrices with put_along_axis and check
    that the 12th power of each one is the identity."""
    get_queue_or_skip()
    ind_dt = dpt.__array_namespace_info__().default_dtypes(device=None)[
        "indexing"
    ]
    # all 24 permutations of (0, 1, 2, 3), one per row
    permutations = dpt.asarray(
        [
            [0, 1, 2, 3],
            [0, 2, 1, 3],
            [2, 0, 1, 3],
            [2, 1, 0, 3],
            [1, 0, 2, 3],
            [1, 2, 0, 3],
            [0, 1, 3, 2],
            [0, 2, 3, 1],
            [2, 0, 3, 1],
            [2, 1, 3, 0],
            [1, 0, 3, 2],
            [1, 2, 3, 0],
            [0, 3, 1, 2],
            [0, 3, 2, 1],
            [2, 3, 0, 1],
            [2, 3, 1, 0],
            [1, 3, 0, 2],
            [1, 3, 2, 0],
            [3, 0, 1, 2],
            [3, 0, 2, 1],
            [3, 2, 0, 1],
            [3, 2, 1, 0],
            [3, 1, 0, 2],
            [3, 1, 2, 0],
        ],
        dtype=ind_dt,
    )
    perm_mats = dpt.zeros((24, 4, 4), dtype=dpt.int64)
    ones = dpt.ones((24, 4, 1), dtype=perm_mats.dtype)
    # Row i of matrix k gets a single one in column permutations[k, i].
    dpt.put_along_axis(
        perm_mats, permutations[..., dpt.newaxis], ones, axis=2
    )
    # A permutation of four elements has order 1, 2, 3 or 4, and each of
    # these divides 12, so the 12th power must be the identity matrix.
    pow2 = perm_mats @ perm_mats
    pow4 = pow2 @ pow2
    pow8 = pow4 @ pow4
    pow12 = pow8 @ pow4
    identity = dpt.eye(4, dtype=perm_mats.dtype)[dpt.newaxis, ...]
    assert dpt.all(pow12 == identity)


def check__extract_impl_validation(fn):
x = dpt.ones(10)
ind = dpt.ones(10, dtype="?")
Expand Down Expand Up @@ -1670,7 +1780,11 @@ def check__put_multi_index_validation(fn):
with pytest.raises(ValueError):
fn(x2, (ind1, ind2), 0, x2)
with pytest.raises(TypeError):
# invalid index type
fn(x2, (ind1, list()), 0, x2)
with pytest.raises(ValueError):
# invalid mode keyword value
fn(x, inds, 0, vals, mode=100)


def test__copy_utils():
Expand Down
Loading