Skip to content

Commit f6971d3

Browse files
committed
Return named tuple for eig, eigh, qr, slogdet, svd functions
1 parent 0c455a6 commit f6971d3

File tree

2 files changed

+89
-38
lines changed

2 files changed

+89
-38
lines changed

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,18 @@
3939
# pylint: disable=invalid-name
4040
# pylint: disable=no-member
4141

42+
from typing import NamedTuple
43+
4244
import numpy
4345
from dpctl.tensor._numpy_helper import normalize_axis_tuple
4446

4547
import dpnp
4648

4749
from .dpnp_utils_linalg import (
50+
EighResult,
51+
QRResult,
52+
SlogdetResult,
53+
SVDResult,
4854
assert_2d,
4955
assert_stacked_2d,
5056
assert_stacked_square,
@@ -66,6 +72,11 @@
6672
)
6773

6874
__all__ = [
75+
"EigResult",
76+
"EighResult",
77+
"QRResult",
78+
"SlogdetResult",
79+
"SVDResult",
6980
"cholesky",
7081
"cond",
7182
"cross",
@@ -100,6 +111,12 @@
100111
]
101112

102113

114+
# pylint:disable=missing-class-docstring
115+
class EigResult(NamedTuple):
116+
eigenvalues: dpnp.ndarray
117+
eigenvectors: dpnp.ndarray
118+
119+
103120
def cholesky(a, /, *, upper=False):
104121
"""
105122
Cholesky decomposition.
@@ -532,7 +549,7 @@ def eig(a):
532549
# Since geev function from OneMKL LAPACK is not implemented yet,
533550
# use NumPy for this calculation.
534551
w_np, v_np = numpy.linalg.eig(dpnp.asnumpy(a))
535-
return (
552+
return EigResult(
536553
dpnp.array(w_np, sycl_queue=a_sycl_queue, usm_type=a_usm_type),
537554
dpnp.array(v_np, sycl_queue=a_sycl_queue, usm_type=a_usm_type),
538555
)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
# pylint: disable=protected-access
3939
# pylint: disable=useless-import-alias
4040

41+
from typing import NamedTuple
42+
4143
import dpctl.tensor._tensor_impl as ti
4244
import dpctl.utils as dpu
4345
import numpy
@@ -50,6 +52,10 @@
5052
from dpnp.linalg import LinAlgError as LinAlgError
5153

5254
__all__ = [
55+
"EighResult",
56+
"QRResult",
57+
"SlogdetResult",
58+
"SVDResult",
5359
"assert_2d",
5460
"assert_stacked_2d",
5561
"assert_stacked_square",
@@ -70,6 +76,29 @@
7076
"dpnp_svd",
7177
]
7278

79+
80+
# pylint:disable=missing-class-docstring
81+
class EighResult(NamedTuple):
82+
eigenvalues: dpnp.ndarray
83+
eigenvectors: dpnp.ndarray
84+
85+
86+
class QRResult(NamedTuple):
87+
Q: dpnp.ndarray
88+
R: dpnp.ndarray
89+
90+
91+
class SlogdetResult(NamedTuple):
92+
sign: dpnp.ndarray
93+
logabsdet: dpnp.ndarray
94+
95+
96+
class SVDResult(NamedTuple):
97+
U: dpnp.ndarray
98+
S: dpnp.ndarray
99+
Vh: dpnp.ndarray
100+
101+
73102
_jobz = {"N": 0, "V": 1}
74103
_upper_lower = {"U": 0, "L": 1}
75104

@@ -162,7 +191,7 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
162191
# Convert to contiguous to align with NumPy
163192
if a_orig_order == "C":
164193
v = dpnp.ascontiguousarray(v)
165-
return w, v
194+
return EighResult(w, v)
166195
return w
167196

168197

@@ -476,7 +505,7 @@ def _batched_qr(a, mode="reduced"):
476505

477506
r = _triu_inplace(r)
478507

479-
return (
508+
return QRResult(
480509
q.reshape(batch_shape + q.shape[-2:]),
481510
r.reshape(batch_shape + r.shape[-2:]),
482511
)
@@ -632,7 +661,7 @@ def _batched_svd(
632661
u = dpnp.ascontiguousarray(u)
633662
vt = dpnp.ascontiguousarray(vt)
634663
# Swap `u` and `vt` for transposed input to restore correct order
635-
return (vt, s, u) if trans_flag else (u, s, vt)
664+
return SVDResult(vt, s, u) if trans_flag else SVDResult(u, s, vt)
636665
return s
637666

638667

@@ -819,9 +848,9 @@ def _hermitian_svd(a, compute_uv):
819848
# but dpnp.linalg.eigh returns s sorted ascending so we re-order
820849
# the eigenvalues and related arrays to have the correct order
821850
if compute_uv:
822-
s, u = dpnp.linalg.eigh(a)
823-
sgn = dpnp.sign(s)
824-
s = dpnp.absolute(s)
851+
s, u = s = dpnp_eigh(a, eigen_mode="V")
852+
sgn = dpnp.sign(s, out=s)
853+
s = dpnp.abs(s, out=s)
825854
sidx = dpnp.argsort(s)[..., ::-1]
826855
# Rearrange the signs according to sorted indices
827856
sgn = dpnp.take_along_axis(sgn, sidx, axis=-1)
@@ -832,11 +861,10 @@ def _hermitian_svd(a, compute_uv):
832861
# Singular values are unsigned, move the sign into v
833862
# Compute V^T adjusting for the sign and conjugating
834863
vt = dpnp.transpose(u * sgn[..., None, :]).conjugate()
835-
return u, s, vt
864+
return SVDResult(u, s, vt)
836865

837-
# TODO: use dpnp.linalg.eighvals when it is updated
838-
s, _ = dpnp.linalg.eigh(a)
839-
s = dpnp.abs(s)
866+
s = dpnp_eigh(a, eigen_mode="N")
867+
s = dpnp.abs(s, out=s)
840868
return dpnp.sort(s)[..., ::-1]
841869

842870

@@ -1423,7 +1451,7 @@ def _zero_batched_qr(a, mode, m, n, k, res_type):
14231451
batch_shape = a.shape[:-2]
14241452

14251453
if mode == "reduced":
1426-
return (
1454+
return QRResult(
14271455
dpnp.empty_like(
14281456
a,
14291457
shape=batch_shape + (m, k),
@@ -1443,7 +1471,7 @@ def _zero_batched_qr(a, mode, m, n, k, res_type):
14431471
usm_type=a_usm_type,
14441472
sycl_queue=a_sycl_queue,
14451473
)
1446-
return (
1474+
return QRResult(
14471475
q,
14481476
dpnp.empty_like(
14491477
a,
@@ -1530,7 +1558,7 @@ def _zero_batched_svd(
15301558
usm_type=usm_type,
15311559
sycl_queue=exec_q,
15321560
)
1533-
return u, s, vt
1561+
return SVDResult(u, s, vt)
15341562
return s
15351563

15361564

@@ -1548,22 +1576,28 @@ def _zero_k_qr(a, mode, m, n, res_type):
15481576
m, n = a.shape
15491577

15501578
if mode == "reduced":
1551-
return dpnp.empty_like(
1552-
a,
1553-
shape=(m, 0),
1554-
dtype=res_type,
1555-
), dpnp.empty_like(
1556-
a,
1557-
shape=(0, n),
1558-
dtype=res_type,
1579+
return QRResult(
1580+
dpnp.empty_like(
1581+
a,
1582+
shape=(m, 0),
1583+
dtype=res_type,
1584+
),
1585+
dpnp.empty_like(
1586+
a,
1587+
shape=(0, n),
1588+
dtype=res_type,
1589+
),
15591590
)
15601591
if mode == "complete":
1561-
return dpnp.identity(
1562-
m, dtype=res_type, sycl_queue=a_sycl_queue, usm_type=a_usm_type
1563-
), dpnp.empty_like(
1564-
a,
1565-
shape=(m, n),
1566-
dtype=res_type,
1592+
return QRResult(
1593+
dpnp.identity(
1594+
m, dtype=res_type, sycl_queue=a_sycl_queue, usm_type=a_usm_type
1595+
),
1596+
dpnp.empty_like(
1597+
a,
1598+
shape=(m, n),
1599+
dtype=res_type,
1600+
),
15671601
)
15681602
if mode == "r":
15691603
return dpnp.empty_like(
@@ -1648,7 +1682,7 @@ def _zero_m_n_batched_svd(
16481682
usm_type=usm_type,
16491683
sycl_queue=exec_q,
16501684
)
1651-
return u, s, vt
1685+
return SVDResult(u, s, vt)
16521686
return s
16531687

16541688

@@ -1692,7 +1726,7 @@ def _zero_m_n_svd(
16921726
usm_type=usm_type,
16931727
sycl_queue=exec_q,
16941728
)
1695-
return u, s, vt
1729+
return SVDResult(u, s, vt)
16961730
return s
16971731

16981732

@@ -1993,7 +2027,7 @@ def dpnp_det(a):
19932027
return det.reshape(shape)
19942028

19952029

1996-
def dpnp_eigh(a, UPLO, eigen_mode="V"):
2030+
def dpnp_eigh(a, UPLO="L", eigen_mode="V"):
19972031
"""
19982032
dpnp_eigh(a, UPLO, eigen_mode="V")
19992033
@@ -2016,7 +2050,7 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
20162050
w = dpnp.empty_like(a, shape=a.shape[:-1], dtype=w_type)
20172051
if eigen_mode == "V":
20182052
v = dpnp.empty_like(a, dtype=v_type)
2019-
return w, v
2053+
return EighResult(w, v)
20202054
return w
20212055

20222056
if a.ndim > 2:
@@ -2097,7 +2131,7 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
20972131
else:
20982132
out_v = v
20992133

2100-
return (w, out_v) if eigen_mode == "V" else w
2134+
return EighResult(w, out_v) if eigen_mode == "V" else w
21012135

21022136

21032137
def dpnp_inv(a):
@@ -2546,7 +2580,7 @@ def dpnp_qr(a, mode="reduced"):
25462580
r = a_t[:, :mc].transpose()
25472581

25482582
r = _triu_inplace(r)
2549-
return (q, r)
2583+
return QRResult(q, r)
25502584

25512585

25522586
def dpnp_solve(a, b):
@@ -2675,7 +2709,7 @@ def dpnp_slogdet(a):
26752709
usm_type=a_usm_type,
26762710
sycl_queue=a_sycl_queue,
26772711
)
2678-
return sign, logdet
2712+
return SlogdetResult(sign, logdet)
26792713

26802714
lu, ipiv, dev_info = _lu_factor(a, res_type)
26812715

@@ -2687,7 +2721,7 @@ def dpnp_slogdet(a):
26872721

26882722
logdet = logdet.astype(logdet_dtype, copy=False)
26892723
singular = dev_info > 0
2690-
return (
2724+
return SlogdetResult(
26912725
dpnp.where(singular, res_type.type(0), sign).reshape(shape),
26922726
dpnp.where(singular, logdet_dtype.type("-inf"), logdet).reshape(shape),
26932727
)
@@ -2815,10 +2849,10 @@ def dpnp_svd(
28152849
# For A^T = V S^T U^T, `u_h` becomes V and `vt_h` becomes U^T.
28162850
# Transpose and swap them back to restore correct order for A.
28172851
if trans_flag:
2818-
return vt_h.T, s_h, u_h.T
2852+
return SVDResult(vt_h.T, s_h, u_h.T)
28192853
# gesvd call writes `u_h` and `vt_h` in Fortran order;
28202854
# Convert to contiguous to align with NumPy
28212855
u_h = dpnp.ascontiguousarray(u_h)
28222856
vt_h = dpnp.ascontiguousarray(vt_h)
2823-
return u_h, s_h, vt_h
2857+
return SVDResult(u_h, s_h, vt_h)
28242858
return s_h

0 commit comments

Comments
 (0)