Skip to content

Commit f19e989

Browse files
authored
Update documentation and clean up implementation of indexing functions (#1913)
* Remove limitations from dpnp.take implementation * Add more test to cover specail cases and increase code coverage * Applied pre-commit hook * Corrected test_over_index * Update docsctrings with resolving typos * Use dpnp.reshape() to change shape and create dpnp array from usm_ndarray result * Remove limitations from dpnp.place implementation * Update relating tests * Roll back changed in dpnp.vander * Remove data sync at the end of function * Update indexing functions * Add missing test scenario * Updated docstring in put_along_axis() and take_along_axis() and rolled back data synchronization * Remove data synchronization for dpnp.put() * Remove data synchronization for dpnp.nonzero() * Remove data synchronization for dpnp.indices() * Remove data synchronization for dpnp.extract() * Remove data sync in dpnp.get_result_array()
1 parent e353d7d commit f19e989

File tree

2 files changed

+89
-49
lines changed

2 files changed

+89
-49
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 87 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):
216216
217217
See also
218218
--------
219-
:obj:`diag_indices_from` : Return the indices to access the main
220-
diagonal of an n-dimensional array.
219+
:obj:`dpnp.diag_indices_from` : Return the indices to access the main
220+
diagonal of an n-dimensional array.
221221
222222
Examples
223223
--------
@@ -276,7 +276,7 @@ def diag_indices_from(arr):
276276
Parameters
277277
----------
278278
arr : {dpnp.ndarray, usm_ndarray}
279-
Array at least 2-D
279+
Array at least 2-D.
280280
281281
Returns
282282
-------
@@ -285,8 +285,8 @@ def diag_indices_from(arr):
285285
286286
See also
287287
--------
288-
:obj:`diag_indices` : Return the indices to access the main
289-
diagonal of an array.
288+
:obj:`dpnp.diag_indices` : Return the indices to access the main diagonal
289+
of an array.
290290
291291
Examples
292292
--------
@@ -570,14 +570,17 @@ def extract(condition, a):
570570

571571
usm_res = dpt.extract(usm_cond, usm_a)
572572

573-
dpnp.synchronize_array_data(usm_res)
574573
return dpnp_array._create_from_usm_ndarray(usm_res)
575574

576575

577576
def fill_diagonal(a, val, wrap=False):
578577
"""
579578
Fill the main diagonal of the given array of any dimensionality.
580579
580+
For an array `a` with ``a.ndim >= 2``, the diagonal is the list of values
581+
``a[i, ..., i]`` with indices ``i`` all identical. This function modifies
582+
the input array in-place without returning a value.
583+
581584
For full documentation refer to :obj:`numpy.fill_diagonal`.
582585
583586
Parameters
@@ -678,11 +681,12 @@ def fill_diagonal(a, val, wrap=False):
678681
679682
"""
680683

681-
dpnp.check_supported_arrays_type(a)
682-
dpnp.check_supported_arrays_type(val, scalar_type=True, all_scalars=True)
684+
usm_a = dpnp.get_usm_ndarray(a)
685+
usm_val = dpnp.get_usm_ndarray_or_scalar(val)
683686

684687
if a.ndim < 2:
685688
raise ValueError("array must be at least 2-d")
689+
686690
end = a.size
687691
if a.ndim == 2:
688692
step = a.shape[1] + 1
@@ -695,18 +699,21 @@ def fill_diagonal(a, val, wrap=False):
695699

696700
# TODO: implement flatiter for slice key
697701
# a.flat[:end:step] = val
702+
# but need to consider use case when `a` is usm_ndarray also
698703
a_sh = a.shape
699-
tmp_a = dpnp.ravel(a)
700-
if dpnp.isscalar(val):
701-
tmp_a[:end:step] = val
704+
tmp_a = dpt.reshape(usm_a, -1)
705+
if dpnp.isscalar(usm_val):
706+
tmp_a[:end:step] = usm_val
702707
else:
703-
flat_val = val.ravel()
708+
usm_val = dpt.reshape(usm_val, -1)
709+
704710
# Setitem can work only if index size equal val size.
705711
# Using loop for general case without dependencies of val size.
706-
for i in range(0, flat_val.size):
707-
tmp_a[step * i : end : step * (i + 1)] = flat_val[i]
708-
tmp_a = dpnp.reshape(tmp_a, a_sh)
709-
a[:] = tmp_a
712+
for i in range(0, usm_val.size):
713+
tmp_a[step * i : end : step * (i + 1)] = usm_val[i]
714+
715+
tmp_a = dpt.reshape(tmp_a, a_sh)
716+
usm_a[:] = tmp_a
710717

711718

712719
def indices(
@@ -758,6 +765,13 @@ def indices(
758765
with grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1)
759766
with dimensions[i] in the i-th place.
760767
768+
See Also
769+
--------
770+
:obj:`dpnp.mgrid` : Return a dense multi-dimensional “meshgrid”.
771+
:obj:`dpnp.ogrid` : Return an open multi-dimensional “meshgrid”.
772+
:obj:`dpnp.meshgrid` : Return a tuple of coordinate matrices from
773+
coordinate vectors.
774+
761775
Examples
762776
--------
763777
>>> import dpnp as np
@@ -800,6 +814,7 @@ def indices(
800814
dimensions = tuple(dimensions)
801815
n = len(dimensions)
802816
shape = (1,) * n
817+
803818
if sparse:
804819
res = ()
805820
else:
@@ -810,6 +825,7 @@ def indices(
810825
usm_type=usm_type,
811826
sycl_queue=sycl_queue,
812827
)
828+
813829
for i, dim in enumerate(dimensions):
814830
idx = dpnp.arange(
815831
dim,
@@ -818,6 +834,7 @@ def indices(
818834
usm_type=usm_type,
819835
sycl_queue=sycl_queue,
820836
).reshape(shape[:i] + (dim,) + shape[i + 1 :])
837+
821838
if sparse:
822839
res = res + (idx,)
823840
else:
@@ -927,10 +944,12 @@ def nonzero(a):
927944
"""
928945
Return the indices of the elements that are non-zero.
929946
930-
Returns a tuple of arrays, one for each dimension of `a`,
931-
containing the indices of the non-zero elements in that
932-
dimension. The values in `a` are always tested and returned in
933-
row-major, C-style order.
947+
Returns a tuple of arrays, one for each dimension of `a`, containing
948+
the indices of the non-zero elements in that dimension. The values in `a`
949+
are always tested and returned in row-major, C-style order.
950+
951+
To group the indices by element, rather than dimension, use
952+
:obj:`dpnp.argwhere`, which returns a row for each non-zero element.
934953
935954
For full documentation refer to :obj:`numpy.nonzero`.
936955
@@ -1005,9 +1024,9 @@ def nonzero(a):
10051024
10061025
"""
10071026

1008-
usx_a = dpnp.get_usm_ndarray(a)
1027+
usm_a = dpnp.get_usm_ndarray(a)
10091028
return tuple(
1010-
dpnp_array._create_from_usm_ndarray(y) for y in dpt.nonzero(usx_a)
1029+
dpnp_array._create_from_usm_ndarray(y) for y in dpt.nonzero(usm_a)
10111030
)
10121031

10131032

@@ -1139,47 +1158,60 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):
11391158
11401159
"""
11411160

1142-
dpnp.check_supported_arrays_type(a)
1143-
1144-
if not dpnp.is_supported_array_type(ind):
1145-
ind = dpnp.asarray(
1146-
ind, dtype=dpnp.intp, sycl_queue=a.sycl_queue, usm_type=a.usm_type
1147-
)
1148-
elif not dpnp.issubdtype(ind.dtype, dpnp.integer):
1149-
ind = dpnp.astype(ind, dtype=dpnp.intp, casting="safe")
1150-
ind = dpnp.ravel(ind)
1151-
1152-
if not dpnp.is_supported_array_type(v):
1153-
v = dpnp.asarray(
1154-
v, dtype=a.dtype, sycl_queue=a.sycl_queue, usm_type=a.usm_type
1155-
)
1156-
if v.size == 0:
1157-
return
1161+
usm_a = dpnp.get_usm_ndarray(a)
11581162

11591163
if not (axis is None or isinstance(axis, int)):
11601164
raise TypeError(f"`axis` must be of integer type, got {type(axis)}")
11611165

1162-
in_a = a
1163-
if axis is None and a.ndim > 1:
1164-
a = dpnp.ravel(in_a)
1165-
11661166
if mode not in ("wrap", "clip"):
11671167
raise ValueError(
11681168
f"clipmode must be one of 'clip' or 'wrap' (got '{mode}')"
11691169
)
11701170

1171-
usm_a = dpnp.get_usm_ndarray(a)
1172-
usm_ind = dpnp.get_usm_ndarray(ind)
1173-
usm_v = dpnp.get_usm_ndarray(v)
1171+
usm_v = dpnp.as_usm_ndarray(
1172+
v,
1173+
dtype=usm_a.dtype,
1174+
usm_type=usm_a.usm_type,
1175+
sycl_queue=usm_a.sycl_queue,
1176+
)
1177+
if usm_v.size == 0:
1178+
return
1179+
1180+
usm_ind = dpnp.as_usm_ndarray(
1181+
ind,
1182+
dtype=dpnp.intp,
1183+
usm_type=usm_a.usm_type,
1184+
sycl_queue=usm_a.sycl_queue,
1185+
)
1186+
1187+
if usm_ind.ndim != 1:
1188+
# dpt.put supports only 1-D array of indices
1189+
usm_ind = dpt.reshape(usm_ind, -1, copy=False)
1190+
1191+
if not dpnp.issubdtype(usm_ind.dtype, dpnp.integer):
1192+
# dpt.put supports only integer dtype for array of indices
1193+
usm_ind = dpt.astype(usm_ind, dpnp.intp, casting="safe")
1194+
1195+
in_usm_a = usm_a
1196+
if axis is None and usm_a.ndim > 1:
1197+
usm_a = dpt.reshape(usm_a, -1)
1198+
11741199
dpt.put(usm_a, usm_ind, usm_v, axis=axis, mode=mode)
1175-
if in_a is not a:
1176-
in_a[:] = a.reshape(in_a.shape, copy=False)
1200+
if in_usm_a._pointer != usm_a._pointer: # pylint: disable=protected-access
1201+
in_usm_a[:] = dpt.reshape(usm_a, in_usm_a.shape, copy=False)
11771202

11781203

11791204
def put_along_axis(a, ind, values, axis):
11801205
"""
11811206
Put values into the destination array by matching 1d index and data slices.
11821207
1208+
This iterates over matching 1d slices oriented along the specified axis in
1209+
the index and data arrays, and uses the former to place values into the
1210+
latter. These slices can be different lengths.
1211+
1212+
Functions returning an index along an `axis`, like :obj:`dpnp.argsort` and
1213+
:obj:`dpnp.argpartition`, produce suitable indices for this function.
1214+
11831215
For full documentation refer to :obj:`numpy.put_along_axis`.
11841216
11851217
Parameters
@@ -1415,6 +1447,13 @@ def take_along_axis(a, indices, axis):
14151447
"""
14161448
Take values from the input array by matching 1d index and data slices.
14171449
1450+
This iterates over matching 1d slices oriented along the specified axis in
1451+
the index and data arrays, and uses the former to look up values in the
1452+
latter. These slices can be different lengths.
1453+
1454+
Functions returning an index along an `axis`, like :obj:`dpnp.argsort` and
1455+
:obj:`dpnp.argpartition`, produce suitable indices for this function.
1456+
14181457
For full documentation refer to :obj:`numpy.take_along_axis`.
14191458
14201459
Parameters
@@ -1428,7 +1467,7 @@ def take_along_axis(a, indices, axis):
14281467
axis : int
14291468
The axis to take 1d slices along. If axis is ``None``, the input
14301469
array is treated as if it had first been flattened to 1d,
1431-
for consistency with `sort` and `argsort`.
1470+
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
14321471
14331472
Returns
14341473
-------

tests/test_indexing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,8 +780,9 @@ def test_fill_diagonal(array, val):
780780

781781
@pytest.mark.parametrize(
782782
"dimension",
783-
[(1,), (2,), (1, 2), (2, 3), (3, 2), [1], [2], [1, 2], [2, 3], [3, 2]],
783+
[(), (1,), (2,), (1, 2), (2, 3), (3, 2), [1], [2], [1, 2], [2, 3], [3, 2]],
784784
ids=[
785+
"()",
785786
"(1, )",
786787
"(2, )",
787788
"(1, 2)",

0 commit comments

Comments
 (0)