Skip to content

Commit 4ef2157

Browse files
Improve performance for dpnp.diag() (#1822)
* Update dpnp.diag to improve performance * Use dpnp.zeros_like() instead of dpnp.zeros() * Update docstrings for dpnp.diagflat/eye/identity * Update cupy test_matrix.py
1 parent 793cf5e commit 4ef2157

File tree

2 files changed

+44
-51
lines changed

2 files changed

+44
-51
lines changed

dpnp/dpnp_iface_arraycreation.py

Lines changed: 39 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -866,9 +866,14 @@ def diag(v, /, k=0, *, device=None, usm_type=None, sycl_queue=None):
866866
v : array_like
867867
Input data, in any form that can be converted to an array. This
868868
includes scalars, lists, lists of tuples, tuples, tuples of tuples,
869-
tuples of lists, and ndarrays. If `v` is a 2-D array, return a copy of
870-
its k-th diagonal. If `v` is a 1-D array, return a 2-D array with `v`
871-
on the k-th diagonal.
869+
tuples of lists, and ndarrays.
870+
If `v` is a 1-D array, return a 2-D array with `v`
871+
on the `k`-th diagonal.
872+
If `v` is a 2-D array and is an instance of
873+
{dpnp.ndarray, usm_ndarray}, then:
874+
- If `device`, `usm_type`, and `sycl_queue` are set to their
875+
default values, returns a read/write view of its k-th diagonal.
876+
- Otherwise, returns a copy of its k-th diagonal.
872877
k : int, optional
873878
Diagonal in question. The default is 0. Use k > 0 for diagonals above
874879
the main diagonal, and k < 0 for diagonals below the main diagonal.
@@ -894,79 +899,62 @@ def diag(v, /, k=0, *, device=None, usm_type=None, sycl_queue=None):
894899
--------
895900
:obj:`diagonal` : Return specified diagonals.
896901
:obj:`diagflat` : Create a 2-D array with the flattened input as a diagonal.
897-
:obj:`trace` : Return sum along diagonals.
898-
:obj:`triu` : Return upper triangle of an array.
899-
:obj:`tril` : Return lower triangle of an array.
902+
:obj:`trace` : Return the sum along diagonals of the array.
903+
:obj:`triu` : Upper triangle of an array.
904+
:obj:`tril` : Lower triangle of an array.
900905
901906
Examples
902907
--------
903908
>>> import dpnp as np
904-
>>> x0 = np.arange(9).reshape((3, 3))
905-
>>> x0
909+
>>> x = np.arange(9).reshape((3, 3))
910+
>>> x
906911
array([[0, 1, 2],
907912
[3, 4, 5],
908913
[6, 7, 8]])
909914
910-
>>> np.diag(x0)
915+
>>> np.diag(x)
911916
array([0, 4, 8])
912-
>>> np.diag(x0, k=1)
917+
>>> np.diag(x, k=1)
913918
array([1, 5])
914-
>>> np.diag(x0, k=-1)
919+
>>> np.diag(x, k=-1)
915920
array([3, 7])
916921
917-
>>> np.diag(np.diag(x0))
922+
>>> np.diag(np.diag(x))
918923
array([[0, 0, 0],
919924
[0, 4, 0],
920925
[0, 0, 8]])
921926
922927
Creating an array on a different device or with a specified usm_type
923928
924-
>>> x = np.diag(x0) # default case
925-
>>> x, x.device, x.usm_type
929+
>>> res = np.diag(x) # default case
930+
>>> res, res.device, res.usm_type
926931
(array([0, 4, 8]), Device(level_zero:gpu:0), 'device')
927932
928-
>>> y = np.diag(x0, device="cpu")
929-
>>> y, y.device, y.usm_type
933+
>>> res_cpu = np.diag(x, device="cpu")
934+
>>> res_cpu, res_cpu.device, res_cpu.usm_type
930935
(array([0, 4, 8]), Device(opencl:cpu:0), 'device')
931936
932-
>>> z = np.diag(x0, usm_type="host")
933-
>>> z, z.device, z.usm_type
937+
>>> res_host = np.diag(x, usm_type="host")
938+
>>> res_host, res_host.device, res_host.usm_type
934939
(array([0, 4, 8]), Device(level_zero:gpu:0), 'host')
935940
936941
"""
937942

938943
if not isinstance(k, int):
939944
raise TypeError(f"An integer is required, but got {type(k)}")
940945

941-
v = dpnp.asarray(v, device=device, usm_type=usm_type, sycl_queue=sycl_queue)
946+
v = dpnp.asanyarray(
947+
v, device=device, usm_type=usm_type, sycl_queue=sycl_queue
948+
)
942949

943-
init0 = max(0, -k)
944-
init1 = max(0, k)
945950
if v.ndim == 1:
946951
size = v.shape[0] + abs(k)
947-
m = dpnp.zeros(
948-
(size, size),
949-
dtype=v.dtype,
950-
usm_type=v.usm_type,
951-
sycl_queue=v.sycl_queue,
952-
)
953-
for i in range(v.shape[0]):
954-
m[(init0 + i), init1 + i] = v[i]
955-
return m
952+
ret = dpnp.zeros_like(v, shape=(size, size))
953+
ret.diagonal(k)[:] = v
954+
return ret
956955

957956
if v.ndim == 2:
958-
size = max(
959-
0, min(v.shape[0], v.shape[0] + k, v.shape[1], v.shape[1] - k)
960-
)
961-
m = dpnp.zeros(
962-
(size,),
963-
dtype=v.dtype,
964-
usm_type=v.usm_type,
965-
sycl_queue=v.sycl_queue,
966-
)
967-
for i in range(size):
968-
m[i] = v[(init0 + i), init1 + i]
969-
return m
957+
return v.diagonal(k)
970958

971959
raise ValueError("Input must be a 1-D or 2-D array.")
972960

@@ -1008,9 +996,9 @@ def diagflat(v, /, k=0, *, device=None, usm_type=None, sycl_queue=None):
1008996
1009997
See Also
1010998
--------
1011-
:obj:`diag` : Return the extracted diagonal or constructed diagonal array.
1012-
:obj:`diagonal` : Return specified diagonals.
1013-
:obj:`trace` : Return sum along diagonals.
999+
:obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
1000+
:obj:`dpnp.diagonal` : Return specified diagonals.
1001+
:obj:`dpnp.trace` : Return sum along diagonals.
10141002
10151003
Examples
10161004
--------
@@ -1324,6 +1312,11 @@ def eye(
13241312
Parameter `like` is supported only with default value ``None``.
13251313
Otherwise, the function raises `NotImplementedError` exception.
13261314
1315+
See Also
1316+
--------
1317+
:obj:`dpnp.identity` : Return the identity array.
1318+
:obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
1319+
13271320
Examples
13281321
--------
13291322
>>> import dpnp as np
@@ -2264,7 +2257,7 @@ def identity(
22642257
:obj:`dpnp.eye` : Return a 2-D array with ones on the diagonal and zeros
22652258
elsewhere.
22662259
:obj:`dpnp.ones` : Return a new array setting values to one.
2267-
:obj:`dpnp.diag` : Return diagonal 2-D array from an input 1-D array.
2260+
:obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
22682261
22692262
Examples
22702263
--------

tests/third_party/cupy/creation_tests/test_matrix.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,35 +27,35 @@ def test_diag3(self, xp):
2727
def test_diag_extraction_from_nested_list(self, xp):
2828
a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
2929
r = xp.diag(a, 1)
30-
self.assertIsInstance(r, xp.ndarray)
30+
assert isinstance(r, xp.ndarray)
3131
return r
3232

3333
@testing.numpy_cupy_array_equal()
3434
def test_diag_extraction_from_nested_tuple(self, xp):
3535
a = ((1, 2, 3), (4, 5, 6), (7, 8, 9))
3636
r = xp.diag(a, -1)
37-
self.assertIsInstance(r, xp.ndarray)
37+
assert isinstance(r, xp.ndarray)
3838
return r
3939

4040
@testing.numpy_cupy_array_equal()
4141
def test_diag_construction(self, xp):
4242
a = testing.shaped_arange((3,), xp)
4343
r = xp.diag(a)
44-
self.assertIsInstance(r, xp.ndarray)
44+
assert isinstance(r, xp.ndarray)
4545
return r
4646

4747
@testing.numpy_cupy_array_equal()
4848
def test_diag_construction_from_list(self, xp):
4949
a = [1, 2, 3]
5050
r = xp.diag(a)
51-
self.assertIsInstance(r, xp.ndarray)
51+
assert isinstance(r, xp.ndarray)
5252
return r
5353

5454
@testing.numpy_cupy_array_equal()
5555
def test_diag_construction_from_tuple(self, xp):
5656
a = (1, 2, 3)
5757
r = xp.diag(a)
58-
self.assertIsInstance(r, xp.ndarray)
58+
assert isinstance(r, xp.ndarray)
5959
return r
6060

6161
def test_diag_scaler(self):

0 commit comments

Comments
 (0)