Skip to content

Commit 614af33

Browse files
authored
Get rid of falling back on numpy in dpnp.put (#1838)
* Get rid of call_origin in dpnp.put * Extended tests for dpnp.put
1 parent 069cad2 commit 614af33

File tree

4 files changed

+229
-188
lines changed

4 files changed

+229
-188
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 65 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def nonzero(a):
677677
[2, 1]])
678678
679679
A common use for ``nonzero`` is to find the indices of an array, where
680-
a condition is ``True.`` Given an array `a`, the condition `a` > 3 is
680+
a condition is ``True``. Given an array `a`, the condition `a` > 3 is
681681
a boolean array and since ``False`` is interpreted as ``0``,
682682
``np.nonzero(a > 3)`` yields the indices of the `a` where the condition is
683683
true.
@@ -736,25 +736,33 @@ def place(x, mask, vals, /):
736736
return call_origin(numpy.place, x, mask, vals, dpnp_inplace=True)
737737

738738

739-
# pylint: disable=redefined-outer-name
740-
def put(a, indices, vals, /, *, axis=None, mode="wrap"):
739+
def put(a, ind, v, /, *, axis=None, mode="wrap"):
741740
"""
742741
Puts values of an array into another array along a given axis.
743742
744743
For full documentation refer to :obj:`numpy.put`.
745744
746-
Limitations
747-
-----------
748-
Parameters `a` and `indices` are supported either as :class:`dpnp.ndarray`
749-
or :class:`dpctl.tensor.usm_ndarray`.
750-
Parameter `indices` is supported as 1-D array of integer data type.
751-
Parameter `vals` must be broadcastable to the shape of `indices`
752-
and has the same data type as `a` if it is as :class:`dpnp.ndarray`
753-
or :class:`dpctl.tensor.usm_ndarray`.
754-
Parameter `mode` is supported with ``wrap``, the default, and ``clip``
755-
values.
756-
Parameter `axis` is supported as integer only.
757-
Otherwise the function will be executed sequentially on CPU.
745+
Parameters
746+
----------
747+
a : {dpnp.ndarray, usm_ndarray}
748+
The array the values will be put into.
749+
ind : {array_like}
750+
Target indices, interpreted as integers.
751+
v : {scalar, array_like}
752+
Values to be put into `a`. Must be broadcastable to the result shape
753+
``a.shape[:axis] + ind.shape + a.shape[axis+1:]``.
754+
axis {None, int}, optional
755+
The axis along which the values will be placed. If `a` is 1-D array,
756+
this argument is optional.
757+
Default: ``None``.
758+
mode : {'wrap', 'clip'}, optional
759+
Specifies how out-of-bounds indices will behave.
760+
761+
- 'wrap': clamps indices to (``-n <= i < n``), then wraps negative
762+
indices.
763+
- 'clip': clips indices to (``0 <= i < n``).
764+
765+
Default: ``'wrap'``.
758766
759767
See Also
760768
--------
@@ -774,49 +782,53 @@ def put(a, indices, vals, /, *, axis=None, mode="wrap"):
774782
Examples
775783
--------
776784
>>> import dpnp as np
777-
>>> x = np.arange(5)
778-
>>> indices = np.array([0, 1])
779-
>>> np.put(x, indices, [-44, -55])
780-
>>> x
781-
array([-44, -55, 2, 3, 4])
785+
>>> a = np.arange(5)
786+
>>> np.put(a, [0, 2], [-44, -55])
787+
>>> a
788+
array([-44, 1, -55, 3, 4])
782789
783-
>>> x = np.arange(5)
784-
>>> indices = np.array([22])
785-
>>> np.put(x, indices, -5, mode='clip')
786-
>>> x
790+
>>> a = np.arange(5)
791+
>>> np.put(a, 22, -5, mode='clip')
792+
>>> a
787793
array([ 0, 1, 2, 3, -5])
788794
789795
"""
790796

791-
if dpnp.is_supported_array_type(a) and dpnp.is_supported_array_type(
792-
indices
793-
):
794-
if indices.ndim != 1 or not dpnp.issubdtype(
795-
indices.dtype, dpnp.integer
796-
):
797-
pass
798-
elif mode not in ("clip", "wrap"):
799-
pass
800-
elif axis is not None and not isinstance(axis, int):
801-
raise TypeError(f"`axis` must be of integer type, got {type(axis)}")
802-
# TODO: remove when #1382(dpctl) is solved
803-
elif dpnp.is_supported_array_type(vals) and a.dtype != vals.dtype:
804-
pass
805-
else:
806-
if axis is None and a.ndim > 1:
807-
a = dpnp.reshape(a, -1)
808-
dpt_array = dpnp.get_usm_ndarray(a)
809-
dpt_indices = dpnp.get_usm_ndarray(indices)
810-
dpt_vals = (
811-
dpnp.get_usm_ndarray(vals)
812-
if isinstance(vals, dpnp_array)
813-
else vals
814-
)
815-
return dpt.put(
816-
dpt_array, dpt_indices, dpt_vals, axis=axis, mode=mode
817-
)
797+
dpnp.check_supported_arrays_type(a)
798+
799+
if not dpnp.is_supported_array_type(ind):
800+
ind = dpnp.asarray(
801+
ind, dtype=dpnp.intp, sycl_queue=a.sycl_queue, usm_type=a.usm_type
802+
)
803+
elif not dpnp.issubdtype(ind.dtype, dpnp.integer):
804+
ind = dpnp.astype(ind, dtype=dpnp.intp, casting="safe")
805+
ind = dpnp.ravel(ind)
806+
807+
if not dpnp.is_supported_array_type(v):
808+
v = dpnp.asarray(
809+
v, dtype=a.dtype, sycl_queue=a.sycl_queue, usm_type=a.usm_type
810+
)
811+
if v.size == 0:
812+
return
813+
814+
if not (axis is None or isinstance(axis, int)):
815+
raise TypeError(f"`axis` must be of integer type, got {type(axis)}")
816+
817+
in_a = a
818+
if axis is None and a.ndim > 1:
819+
a = dpnp.ravel(in_a)
820+
821+
if mode not in ("wrap", "clip"):
822+
raise ValueError(
823+
f"clipmode must be one of 'clip' or 'wrap' (got '{mode}')"
824+
)
818825

819-
return call_origin(numpy.put, a, indices, vals, mode, dpnp_inplace=True)
826+
usm_a = dpnp.get_usm_ndarray(a)
827+
usm_ind = dpnp.get_usm_ndarray(ind)
828+
usm_v = dpnp.get_usm_ndarray(v)
829+
dpt.put(usm_a, usm_ind, usm_v, axis=axis, mode=mode)
830+
if in_a is not a:
831+
in_a[:] = a.reshape(in_a.shape, copy=False)
820832

821833

822834
# pylint: disable=redefined-outer-name
@@ -1194,7 +1206,7 @@ def triu_indices(n, k=0, m=None):
11941206
-------
11951207
inds : tuple, shape(2) of ndarrays, shape(`n`)
11961208
The indices for the triangle. The returned tuple contains two arrays,
1197-
each with the indices along one dimension of the array. Can be used
1209+
each with the indices along one dimension of the array. Can be used
11981210
to slice a ndarray of shape(`n`, `n`).
11991211
"""
12001212

tests/helper.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,6 @@ def get_integer_dtypes():
9191
return [dpnp.int32, dpnp.int64]
9292

9393

94-
def get_integer_dtypes():
95-
"""
96-
Build a list of integer types supported by DPNP.
97-
"""
98-
99-
return [dpnp.int32, dpnp.int64]
100-
101-
10294
def get_complex_dtypes(device=None):
10395
"""
10496
Build a list of complex types supported by DPNP based on device capabilities.

0 commit comments

Comments
 (0)