@@ -216,8 +216,8 @@ def diag_indices(n, ndim=2, device=None, usm_type="device", sycl_queue=None):
216
216
217
217
See also
218
218
--------
219
- :obj:`diag_indices_from` : Return the indices to access the main
220
- diagonal of an n-dimensional array.
219
+ :obj:`dpnp. diag_indices_from` : Return the indices to access the main
220
+ diagonal of an n-dimensional array.
221
221
222
222
Examples
223
223
--------
@@ -276,7 +276,7 @@ def diag_indices_from(arr):
276
276
Parameters
277
277
----------
278
278
arr : {dpnp.ndarray, usm_ndarray}
279
- Array at least 2-D
279
+ Array at least 2-D.
280
280
281
281
Returns
282
282
-------
@@ -285,8 +285,8 @@ def diag_indices_from(arr):
285
285
286
286
See also
287
287
--------
288
- :obj:`diag_indices` : Return the indices to access the main
289
- diagonal of an array.
288
+ :obj:`dpnp. diag_indices` : Return the indices to access the main diagonal
289
+ of an array.
290
290
291
291
Examples
292
292
--------
@@ -570,14 +570,17 @@ def extract(condition, a):
570
570
571
571
usm_res = dpt .extract (usm_cond , usm_a )
572
572
573
- dpnp .synchronize_array_data (usm_res )
574
573
return dpnp_array ._create_from_usm_ndarray (usm_res )
575
574
576
575
577
576
def fill_diagonal (a , val , wrap = False ):
578
577
"""
579
578
Fill the main diagonal of the given array of any dimensionality.
580
579
580
+ For an array `a` with ``a.ndim >= 2``, the diagonal is the list of values
581
+ ``a[i, ..., i]`` with indices ``i`` all identical. This function modifies
582
+ the input array in-place without returning a value.
583
+
581
584
For full documentation refer to :obj:`numpy.fill_diagonal`.
582
585
583
586
Parameters
@@ -678,11 +681,12 @@ def fill_diagonal(a, val, wrap=False):
678
681
679
682
"""
680
683
681
- dpnp .check_supported_arrays_type (a )
682
- dpnp .check_supported_arrays_type (val , scalar_type = True , all_scalars = True )
684
+ usm_a = dpnp .get_usm_ndarray (a )
685
+ usm_val = dpnp .get_usm_ndarray_or_scalar (val )
683
686
684
687
if a .ndim < 2 :
685
688
raise ValueError ("array must be at least 2-d" )
689
+
686
690
end = a .size
687
691
if a .ndim == 2 :
688
692
step = a .shape [1 ] + 1
@@ -695,18 +699,21 @@ def fill_diagonal(a, val, wrap=False):
695
699
696
700
# TODO: implement flatiter for slice key
697
701
# a.flat[:end:step] = val
702
+ # but need to consider use case when `a` is usm_ndarray also
698
703
a_sh = a .shape
699
- tmp_a = dpnp . ravel ( a )
700
- if dpnp .isscalar (val ):
701
- tmp_a [:end :step ] = val
704
+ tmp_a = dpt . reshape ( usm_a , - 1 )
705
+ if dpnp .isscalar (usm_val ):
706
+ tmp_a [:end :step ] = usm_val
702
707
else :
703
- flat_val = val .ravel ()
708
+ usm_val = dpt .reshape (usm_val , - 1 )
709
+
704
710
# Setitem can work only if index size equal val size.
705
711
# Using loop for general case without dependencies of val size.
706
- for i in range (0 , flat_val .size ):
707
- tmp_a [step * i : end : step * (i + 1 )] = flat_val [i ]
708
- tmp_a = dpnp .reshape (tmp_a , a_sh )
709
- a [:] = tmp_a
712
+ for i in range (0 , usm_val .size ):
713
+ tmp_a [step * i : end : step * (i + 1 )] = usm_val [i ]
714
+
715
+ tmp_a = dpt .reshape (tmp_a , a_sh )
716
+ usm_a [:] = tmp_a
710
717
711
718
712
719
def indices (
@@ -758,6 +765,13 @@ def indices(
758
765
with grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1)
759
766
with dimensions[i] in the i-th place.
760
767
768
+ See Also
769
+ --------
770
+ :obj:`dpnp.mgrid` : Return a dense multi-dimensional “meshgrid”.
771
+ :obj:`dpnp.ogrid` : Return an open multi-dimensional “meshgrid”.
772
+ :obj:`dpnp.meshgrid` : Return a tuple of coordinate matrices from
773
+ coordinate vectors.
774
+
761
775
Examples
762
776
--------
763
777
>>> import dpnp as np
@@ -800,6 +814,7 @@ def indices(
800
814
dimensions = tuple (dimensions )
801
815
n = len (dimensions )
802
816
shape = (1 ,) * n
817
+
803
818
if sparse :
804
819
res = ()
805
820
else :
@@ -810,6 +825,7 @@ def indices(
810
825
usm_type = usm_type ,
811
826
sycl_queue = sycl_queue ,
812
827
)
828
+
813
829
for i , dim in enumerate (dimensions ):
814
830
idx = dpnp .arange (
815
831
dim ,
@@ -818,6 +834,7 @@ def indices(
818
834
usm_type = usm_type ,
819
835
sycl_queue = sycl_queue ,
820
836
).reshape (shape [:i ] + (dim ,) + shape [i + 1 :])
837
+
821
838
if sparse :
822
839
res = res + (idx ,)
823
840
else :
@@ -927,10 +944,12 @@ def nonzero(a):
927
944
"""
928
945
Return the indices of the elements that are non-zero.
929
946
930
- Returns a tuple of arrays, one for each dimension of `a`,
931
- containing the indices of the non-zero elements in that
932
- dimension. The values in `a` are always tested and returned in
933
- row-major, C-style order.
947
+ Returns a tuple of arrays, one for each dimension of `a`, containing
948
+ the indices of the non-zero elements in that dimension. The values in `a`
949
+ are always tested and returned in row-major, C-style order.
950
+
951
+ To group the indices by element, rather than dimension, use
952
+ :obj:`dpnp.argwhere`, which returns a row for each non-zero element.
934
953
935
954
For full documentation refer to :obj:`numpy.nonzero`.
936
955
@@ -1005,9 +1024,9 @@ def nonzero(a):
1005
1024
1006
1025
"""
1007
1026
1008
- usx_a = dpnp .get_usm_ndarray (a )
1027
+ usm_a = dpnp .get_usm_ndarray (a )
1009
1028
return tuple (
1010
- dpnp_array ._create_from_usm_ndarray (y ) for y in dpt .nonzero (usx_a )
1029
+ dpnp_array ._create_from_usm_ndarray (y ) for y in dpt .nonzero (usm_a )
1011
1030
)
1012
1031
1013
1032
@@ -1139,47 +1158,60 @@ def put(a, ind, v, /, *, axis=None, mode="wrap"):
1139
1158
1140
1159
"""
1141
1160
1142
- dpnp .check_supported_arrays_type (a )
1143
-
1144
- if not dpnp .is_supported_array_type (ind ):
1145
- ind = dpnp .asarray (
1146
- ind , dtype = dpnp .intp , sycl_queue = a .sycl_queue , usm_type = a .usm_type
1147
- )
1148
- elif not dpnp .issubdtype (ind .dtype , dpnp .integer ):
1149
- ind = dpnp .astype (ind , dtype = dpnp .intp , casting = "safe" )
1150
- ind = dpnp .ravel (ind )
1151
-
1152
- if not dpnp .is_supported_array_type (v ):
1153
- v = dpnp .asarray (
1154
- v , dtype = a .dtype , sycl_queue = a .sycl_queue , usm_type = a .usm_type
1155
- )
1156
- if v .size == 0 :
1157
- return
1161
+ usm_a = dpnp .get_usm_ndarray (a )
1158
1162
1159
1163
if not (axis is None or isinstance (axis , int )):
1160
1164
raise TypeError (f"`axis` must be of integer type, got { type (axis )} " )
1161
1165
1162
- in_a = a
1163
- if axis is None and a .ndim > 1 :
1164
- a = dpnp .ravel (in_a )
1165
-
1166
1166
if mode not in ("wrap" , "clip" ):
1167
1167
raise ValueError (
1168
1168
f"clipmode must be one of 'clip' or 'wrap' (got '{ mode } ')"
1169
1169
)
1170
1170
1171
- usm_a = dpnp .get_usm_ndarray (a )
1172
- usm_ind = dpnp .get_usm_ndarray (ind )
1173
- usm_v = dpnp .get_usm_ndarray (v )
1171
+ usm_v = dpnp .as_usm_ndarray (
1172
+ v ,
1173
+ dtype = usm_a .dtype ,
1174
+ usm_type = usm_a .usm_type ,
1175
+ sycl_queue = usm_a .sycl_queue ,
1176
+ )
1177
+ if usm_v .size == 0 :
1178
+ return
1179
+
1180
+ usm_ind = dpnp .as_usm_ndarray (
1181
+ ind ,
1182
+ dtype = dpnp .intp ,
1183
+ usm_type = usm_a .usm_type ,
1184
+ sycl_queue = usm_a .sycl_queue ,
1185
+ )
1186
+
1187
+ if usm_ind .ndim != 1 :
1188
+ # dpt.put supports only 1-D array of indices
1189
+ usm_ind = dpt .reshape (usm_ind , - 1 , copy = False )
1190
+
1191
+ if not dpnp .issubdtype (usm_ind .dtype , dpnp .integer ):
1192
+ # dpt.put supports only integer dtype for array of indices
1193
+ usm_ind = dpt .astype (usm_ind , dpnp .intp , casting = "safe" )
1194
+
1195
+ in_usm_a = usm_a
1196
+ if axis is None and usm_a .ndim > 1 :
1197
+ usm_a = dpt .reshape (usm_a , - 1 )
1198
+
1174
1199
dpt .put (usm_a , usm_ind , usm_v , axis = axis , mode = mode )
1175
- if in_a is not a :
1176
- in_a [:] = a .reshape (in_a .shape , copy = False )
1200
+ if in_usm_a . _pointer != usm_a . _pointer : # pylint: disable=protected-access
1201
+ in_usm_a [:] = dpt .reshape (usm_a , in_usm_a .shape , copy = False )
1177
1202
1178
1203
1179
1204
def put_along_axis (a , ind , values , axis ):
1180
1205
"""
1181
1206
Put values into the destination array by matching 1d index and data slices.
1182
1207
1208
+ This iterates over matching 1d slices oriented along the specified axis in
1209
+ the index and data arrays, and uses the former to place values into the
1210
+ latter. These slices can be different lengths.
1211
+
1212
+ Functions returning an index along an `axis`, like :obj:`dpnp.argsort` and
1213
+ :obj:`dpnp.argpartition`, produce suitable indices for this function.
1214
+
1183
1215
For full documentation refer to :obj:`numpy.put_along_axis`.
1184
1216
1185
1217
Parameters
@@ -1415,6 +1447,13 @@ def take_along_axis(a, indices, axis):
1415
1447
"""
1416
1448
Take values from the input array by matching 1d index and data slices.
1417
1449
1450
+ This iterates over matching 1d slices oriented along the specified axis in
1451
+ the index and data arrays, and uses the former to look up values in the
1452
+ latter. These slices can be different lengths.
1453
+
1454
+ Functions returning an index along an `axis`, like :obj:`dpnp.argsort` and
1455
+ :obj:`dpnp.argpartition`, produce suitable indices for this function.
1456
+
1418
1457
For full documentation refer to :obj:`numpy.take_along_axis`.
1419
1458
1420
1459
Parameters
@@ -1428,7 +1467,7 @@ def take_along_axis(a, indices, axis):
1428
1467
axis : int
1429
1468
The axis to take 1d slices along. If axis is ``None``, the input
1430
1469
array is treated as if it had first been flattened to 1d,
1431
- for consistency with ` sort` and ` argsort`.
1470
+ for consistency with :obj:`dpnp. sort` and :obj:`dpnp. argsort`.
1432
1471
1433
1472
Returns
1434
1473
-------
0 commit comments