Skip to content

Commit 05af730

Browse files
Update dpnp.diagonal() (#1817)
* Update dpnp.diagonal to support offset and axes * Update cupy tests for dpnp.diagonal() * Update dpnp tests for dpnp.diagonal() * Remove pytest.skip for a closed issue * Extend test cases in test_diagonal_errors * Update dpnp.ndarray.diagonal * Update docstrings for dpnp.diagonal * Address remarks * Simplify getting axes_order * Add cupy test_indexing.py to test scope * Remove TODO from dpnp_det/dpnp_slogdet * Improve test coverage * Update the calculation of result parameters
1 parent 5d94ca8 commit 05af730

File tree

12 files changed

+218
-92
lines changed

12 files changed

+218
-92
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ env:
3838
test_umath.py
3939
test_usm_type.py
4040
third_party/cupy/core_tests
41+
third_party/cupy/indexing_tests/test_indexing.py
4142
third_party/cupy/linalg_tests/test_decomposition.py
4243
third_party/cupy/linalg_tests/test_norms.py
4344
third_party/cupy/linalg_tests/test_product.py

dpnp/dpnp_algo/dpnp_algo_indexing.pxi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ and the rest of the library
3838
__all__ += [
3939
"dpnp_choose",
4040
"dpnp_diag_indices",
41-
"dpnp_diagonal",
4241
"dpnp_fill_diagonal",
4342
"dpnp_putmask",
4443
"dpnp_select",

dpnp/dpnp_array.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -744,15 +744,26 @@ def cumsum(self, axis=None, dtype=None, out=None):
744744

745745
# 'data',
746746

747-
def diagonal(input, offset=0, axis1=0, axis2=1):
747+
def diagonal(self, offset=0, axis1=0, axis2=1):
748748
"""
749749
Return specified diagonals.
750750
751751
Refer to :obj:`dpnp.diagonal` for full documentation.
752752
753+
See Also
754+
--------
755+
:obj:`dpnp.diagonal` : Equivalent function.
756+
757+
Examples
758+
--------
759+
>>> import dpnp as np
760+
>>> a = np.arange(4).reshape(2,2)
761+
>>> a.diagonal()
762+
array([0, 3])
763+
753764
"""
754765

755-
return dpnp.diagonal(input, offset, axis1, axis2)
766+
return dpnp.diagonal(self, offset=offset, axis1=axis1, axis2=axis2)
756767

757768
def dot(self, b, out=None):
758769
"""

dpnp/dpnp_iface_indexing.py

Lines changed: 141 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
from .dpnp_algo import (
4848
dpnp_choose,
4949
dpnp_diag_indices,
50-
dpnp_diagonal,
5150
dpnp_fill_diagonal,
5251
dpnp_putmask,
5352
dpnp_select,
@@ -276,34 +275,156 @@ def diag_indices_from(x1):
276275
return call_origin(numpy.diag_indices_from, x1)
277276

278277

279-
def diagonal(x1, offset=0, axis1=0, axis2=1):
278+
def diagonal(a, offset=0, axis1=0, axis2=1):
280279
"""
281280
Return specified diagonals.
282281
283282
For full documentation refer to :obj:`numpy.diagonal`.
284283
285-
Limitations
286-
-----------
287-
Input array is supported as :obj:`dpnp.ndarray`.
288-
Parameters `axis1` and `axis2` are supported only with default values.
289-
Otherwise the function will be executed sequentially on CPU.
284+
Parameters
285+
----------
286+
a : {dpnp.ndarray, usm_ndarray}
287+
Array from which the diagonals are taken.
288+
offset : int, optional
289+
Offset of the diagonal from the main diagonal. Can be positive or
290+
negative. Defaults to main diagonal (``0``).
291+
axis1 : int, optional
292+
Axis to be used as the first axis of the 2-D sub-arrays from which
293+
the diagonals should be taken. Defaults to first axis (``0``).
294+
axis2 : int, optional
295+
Axis to be used as the second axis of the 2-D sub-arrays from
296+
which the diagonals should be taken. Defaults to second axis (``1``).
297+
298+
Returns
299+
-------
300+
array_of_diagonals : dpnp.ndarray
301+
If `a` is 2-D, then a 1-D array containing the diagonal and of the
302+
same type as `a` is returned.
303+
If ``a.ndim > 2``, then the dimensions specified by `axis1` and `axis2`
304+
are removed, and a new axis inserted at the end corresponding to the
305+
diagonal.
306+
307+
See Also
308+
--------
309+
:obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
310+
:obj:`dpnp.diagflat` : Create a two-dimensional array
311+
with the flattened input as a diagonal.
312+
:obj:`dpnp.trace` : Return the sum along diagonals of the array.
313+
314+
Examples
315+
--------
316+
>>> import dpnp as np
317+
>>> a = np.arange(4).reshape(2,2)
318+
>>> a
319+
array([[0, 1],
320+
[2, 3]])
321+
>>> a.diagonal()
322+
array([0, 3])
323+
>>> a.diagonal(1)
324+
array([1])
325+
326+
A 3-D example:
327+
328+
>>> a = np.arange(8).reshape(2,2,2)
329+
>>> a
330+
array([[[0, 1],
331+
[2, 3]],
332+
[[4, 5],
333+
[6, 7]]])
334+
>>> a.diagonal(0, # Main diagonals of two arrays created by skipping
335+
... 0, # across the outer(left)-most axis last and
336+
... 1) # the "middle" (row) axis first.
337+
array([[0, 6],
338+
[1, 7]])
339+
340+
The sub-arrays whose main diagonals we just obtained; note that each
341+
corresponds to fixing the right-most (column) axis, and that the
342+
diagonals are "packed" in rows.
343+
344+
>>> a[:,:,0] # main diagonal is [0 6]
345+
array([[0, 2],
346+
[4, 6]])
347+
>>> a[:,:,1] # main diagonal is [1 7]
348+
array([[1, 3],
349+
[5, 7]])
350+
351+
The anti-diagonal can be obtained by reversing the order of elements
352+
using either `dpnp.flipud` or `dpnp.fliplr`.
353+
354+
>>> a = np.arange(9).reshape(3, 3)
355+
>>> a
356+
array([[0, 1, 2],
357+
[3, 4, 5],
358+
[6, 7, 8]])
359+
>>> np.fliplr(a).diagonal() # Horizontal flip
360+
array([2, 4, 6])
361+
>>> np.flipud(a).diagonal() # Vertical flip
362+
array([6, 4, 2])
363+
364+
Note that the order in which the diagonal is retrieved varies depending
365+
on the flip function.
290366
291367
"""
292368

293-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
294-
if x1_desc:
295-
if not isinstance(offset, int):
296-
pass
297-
elif offset < 0:
298-
pass
299-
elif axis1 != 0:
300-
pass
301-
elif axis2 != 1:
302-
pass
303-
else:
304-
return dpnp_diagonal(x1_desc, offset).get_pyobj()
369+
dpnp.check_supported_arrays_type(a)
370+
a_ndim = a.ndim
305371

306-
return call_origin(numpy.diagonal, x1, offset, axis1, axis2)
372+
if a_ndim < 2:
373+
raise ValueError("diag requires an array of at least two dimensions")
374+
375+
if not isinstance(offset, int):
376+
raise TypeError(
377+
f"`offset` must be an integer data type, but got {type(offset)}"
378+
)
379+
380+
axis1 = normalize_axis_index(axis1, a_ndim)
381+
axis2 = normalize_axis_index(axis2, a_ndim)
382+
383+
if axis1 == axis2:
384+
raise ValueError("`axis1` and `axis2` cannot be the same")
385+
386+
# get list of the order of all axes excluding the two target axes
387+
axes_order = [i for i in range(a_ndim) if i not in [axis1, axis2]]
388+
389+
# transpose the input array to put the target axes at the end
390+
# to simplify diagonal extraction
391+
if offset >= 0:
392+
a = dpnp.transpose(a, axes_order + [axis1, axis2])
393+
else:
394+
a = dpnp.transpose(a, axes_order + [axis2, axis1])
395+
offset = -offset
396+
397+
a_shape = a.shape
398+
a_straides = a.strides
399+
n, m = a_shape[-2:]
400+
st_n, st_m = a_straides[-2:]
401+
# pylint: disable=W0212
402+
a_element_offset = a.get_array()._element_offset
403+
404+
# Compute shape, strides and offset of the resulting diagonal array
405+
# based on the input offset
406+
if offset == 0:
407+
out_shape = a_shape[:-2] + (min(n, m),)
408+
out_strides = a_straides[:-2] + (st_n + st_m,)
409+
out_offset = a_element_offset
410+
elif 0 < offset < m:
411+
out_shape = a_shape[:-2] + (min(n, m - offset),)
412+
out_strides = a_straides[:-2] + (st_n + st_m,)
413+
out_offset = a_element_offset + st_m * offset
414+
else:
415+
out_shape = a_shape[:-2] + (0,)
416+
out_strides = a_straides[:-2] + (1,)
417+
out_offset = a_element_offset
418+
419+
return dpnp_array._create_from_usm_ndarray(
420+
dpt.usm_ndarray(
421+
out_shape,
422+
dtype=a.dtype,
423+
buffer=a.get_array(),
424+
strides=out_strides,
425+
offset=out_offset,
426+
)
427+
)
307428

308429

309430
def extract(condition, x):

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -953,11 +953,7 @@ def dpnp_det(a):
953953

954954
lu, ipiv, dev_info = _lu_factor(a, res_type)
955955

956-
# Transposing 'lu' to swap the last two axes for compatibility
957-
# with 'dpnp.diagonal' as it does not support 'axis1' and 'axis2' arguments.
958-
# TODO: Replace with 'dpnp.diagonal(lu, axis1=-2, axis2=-1)' when supported.
959-
lu_transposed = lu.transpose(-2, -1, *range(lu.ndim - 2))
960-
diag = dpnp.diagonal(lu_transposed)
956+
diag = dpnp.diagonal(lu, axis1=-2, axis2=-1)
961957

962958
det = dpnp.prod(dpnp.abs(diag), axis=-1)
963959

@@ -2112,11 +2108,7 @@ def dpnp_slogdet(a):
21122108

21132109
lu, ipiv, dev_info = _lu_factor(a, res_type)
21142110

2115-
# Transposing 'lu' to swap the last two axes for compatibility
2116-
# with 'dpnp.diagonal' as it does not support 'axis1' and 'axis2' arguments.
2117-
# TODO: Replace with 'dpnp.diagonal(lu, axis1=-2, axis2=-1)' when supported.
2118-
lu_transposed = lu.transpose(-2, -1, *range(lu.ndim - 2))
2119-
diag = dpnp.diagonal(lu_transposed)
2111+
diag = dpnp.diagonal(lu, axis1=-2, axis2=-1)
21202112

21212113
logdet = dpnp.log(dpnp.abs(diag)).sum(axis=-1)
21222114

tests/skipped_tests.tbl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compr
148148
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
149149
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
150150
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
151-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal
152-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_invalid1
153-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_invalid2
154-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative1
155-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative2
156-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative3
157-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative4
158-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative5
159151
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
160152
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
161153
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,14 +210,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_
210210
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_broadcastable_complex
211211
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_odd_shaped_non_broadcastable
212212
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_type_error_condlist
213-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal
214-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_invalid1
215-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_invalid2
216-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative1
217-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative2
218-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative3
219-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative4
220-
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_diagonal_negative5
221213

222214
tests/third_party/cupy/indexing_tests/test_insert.py::TestDiagIndices_param_0_{n=2, ndim=2}::test_diag_indices
223215
tests/third_party/cupy/indexing_tests/test_insert.py::TestDiagIndices_param_1_{n=2, ndim=3}::test_diag_indices

tests/test_indexing.py

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -322,48 +322,68 @@ def test_choose():
322322
assert_array_equal(expected, result)
323323

324324

325-
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_bool=True))
326-
@pytest.mark.parametrize("offset", [0, 1], ids=["0", "1"])
327-
@pytest.mark.parametrize(
328-
"array",
329-
[
330-
[[0, 0], [0, 0]],
331-
[[1, 2], [1, 2]],
332-
[[1, 2], [3, 4]],
333-
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
334-
[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]],
335-
[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]],
336-
[
337-
[[[1, 2], [3, 4]], [[1, 2], [2, 1]]],
338-
[[[1, 3], [3, 1]], [[0, 1], [1, 3]]],
339-
],
340-
[
341-
[[[1, 2, 3], [3, 4, 5]], [[1, 2, 3], [2, 1, 0]]],
342-
[[[1, 3, 5], [3, 1, 0]], [[0, 1, 2], [1, 3, 4]]],
325+
class TestDiagonal:
326+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
327+
@pytest.mark.parametrize("offset", [-3, -1, 0, 1, 3])
328+
@pytest.mark.parametrize(
329+
"shape",
330+
[(2, 2), (3, 3), (2, 5), (3, 2, 2), (2, 2, 2, 2), (2, 2, 2, 3)],
331+
ids=[
332+
"(2,2)",
333+
"(3,3)",
334+
"(2,5)",
335+
"(3,2,2)",
336+
"(2,2,2,2)",
337+
"(2,2,2,3)",
343338
],
339+
)
340+
def test_diagonal_offset(self, shape, dtype, offset):
341+
a = numpy.arange(numpy.prod(shape), dtype=dtype).reshape(shape)
342+
a_dp = dpnp.array(a)
343+
expected = numpy.diagonal(a, offset)
344+
result = dpnp.diagonal(a_dp, offset)
345+
assert_array_equal(expected, result)
346+
347+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
348+
@pytest.mark.parametrize(
349+
"shape, axis_pairs",
344350
[
345-
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
346-
[[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]],
351+
((3, 4), [(0, 1), (1, 0)]),
352+
((3, 4, 5), [(0, 1), (1, 2), (0, 2)]),
353+
((4, 3, 5, 2), [(0, 1), (1, 2), (2, 3), (0, 3)]),
347354
],
348-
],
349-
ids=[
350-
"[[0, 0], [0, 0]]",
351-
"[[1, 2], [1, 2]]",
352-
"[[1, 2], [3, 4]]",
353-
"[[0, 1, 2], [3, 4, 5], [6, 7, 8]]",
354-
"[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]",
355-
"[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]",
356-
"[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]",
357-
"[[[[1, 2, 3], [3, 4, 5]], [[1, 2, 3], [2, 1, 0]]], [[[1, 3, 5], [3, 1, 0]], [[0, 1, 2], [1, 3, 4]]]]",
358-
"[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]]",
359-
],
360-
)
361-
def test_diagonal(array, arr_dtype, offset):
362-
a = numpy.array(array, dtype=arr_dtype)
363-
ia = dpnp.array(a)
364-
expected = numpy.diagonal(a, offset)
365-
result = dpnp.diagonal(ia, offset)
366-
assert_array_equal(expected, result)
355+
)
356+
def test_diagonal_axes(self, shape, axis_pairs, dtype):
357+
a = numpy.arange(numpy.prod(shape), dtype=dtype).reshape(shape)
358+
a_dp = dpnp.array(a)
359+
for axis1, axis2 in axis_pairs:
360+
expected = numpy.diagonal(a, axis1=axis1, axis2=axis2)
361+
result = dpnp.diagonal(a_dp, axis1=axis1, axis2=axis2)
362+
assert_array_equal(expected, result)
363+
364+
def test_diagonal_errors(self):
365+
a = dpnp.arange(12).reshape(3, 4)
366+
367+
# unsupported type
368+
a_np = dpnp.asnumpy(a)
369+
assert_raises(TypeError, dpnp.diagonal, a_np)
370+
371+
# a.ndim < 2
372+
a_ndim_1 = a.flatten()
373+
assert_raises(ValueError, dpnp.diagonal, a_ndim_1)
374+
375+
# unsupported type `offset`
376+
assert_raises(TypeError, dpnp.diagonal, a, offset=1.0)
377+
assert_raises(TypeError, dpnp.diagonal, a, offset=[0])
378+
379+
# axes are out of bounds
380+
assert_raises(numpy.AxisError, a.diagonal, axis1=0, axis2=5)
381+
assert_raises(numpy.AxisError, a.diagonal, axis1=5, axis2=0)
382+
assert_raises(numpy.AxisError, a.diagonal, axis1=5, axis2=5)
383+
384+
# same axes
385+
assert_raises(ValueError, a.diagonal, axis1=1, axis2=1)
386+
assert_raises(ValueError, a.diagonal, axis1=1, axis2=-1)
367387

368388

369389
@pytest.mark.parametrize("arr_dtype", get_all_dtypes())

0 commit comments

Comments
 (0)