3838# pylint: disable=protected-access
3939# pylint: disable=useless-import-alias
4040
41+ from typing import NamedTuple
42+
4143import dpctl .tensor ._tensor_impl as ti
4244import dpctl .utils as dpu
4345import numpy
5052from dpnp .linalg import LinAlgError as LinAlgError
5153
5254__all__ = [
55+ "EighResult" ,
56+ "QRResult" ,
57+ "SlogdetResult" ,
58+ "SVDResult" ,
5359 "assert_2d" ,
5460 "assert_stacked_2d" ,
5561 "assert_stacked_square" ,
7076 "dpnp_svd" ,
7177]
7278
79+
80+ # pylint:disable=missing-class-docstring
81+ class EighResult (NamedTuple ):
82+ eigenvalues : dpnp .ndarray
83+ eigenvectors : dpnp .ndarray
84+
85+
86+ class QRResult (NamedTuple ):
87+ Q : dpnp .ndarray
88+ R : dpnp .ndarray
89+
90+
91+ class SlogdetResult (NamedTuple ):
92+ sign : dpnp .ndarray
93+ logabsdet : dpnp .ndarray
94+
95+
96+ class SVDResult (NamedTuple ):
97+ U : dpnp .ndarray
98+ S : dpnp .ndarray
99+ Vh : dpnp .ndarray
100+
101+
73102_jobz = {"N" : 0 , "V" : 1 }
74103_upper_lower = {"U" : 0 , "L" : 1 }
75104
@@ -162,7 +191,7 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
162191 # Convert to contiguous to align with NumPy
163192 if a_orig_order == "C" :
164193 v = dpnp .ascontiguousarray (v )
165- return w , v
194+ return EighResult ( w , v )
166195 return w
167196
168197
@@ -476,7 +505,7 @@ def _batched_qr(a, mode="reduced"):
476505
477506 r = _triu_inplace (r )
478507
479- return (
508+ return QRResult (
480509 q .reshape (batch_shape + q .shape [- 2 :]),
481510 r .reshape (batch_shape + r .shape [- 2 :]),
482511 )
@@ -632,7 +661,7 @@ def _batched_svd(
632661 u = dpnp .ascontiguousarray (u )
633662 vt = dpnp .ascontiguousarray (vt )
634663 # Swap `u` and `vt` for transposed input to restore correct order
635- return (vt , s , u ) if trans_flag else (u , s , vt )
664+ return SVDResult (vt , s , u ) if trans_flag else SVDResult (u , s , vt )
636665 return s
637666
638667
@@ -819,9 +848,9 @@ def _hermitian_svd(a, compute_uv):
819848 # but dpnp.linalg.eigh returns s sorted ascending so we re-order
820849 # the eigenvalues and related arrays to have the correct order
821850 if compute_uv :
822- s , u = dpnp . linalg . eigh ( a )
823- sgn = dpnp .sign (s )
824- s = dpnp .absolute ( s )
851+ s , u = s = dpnp_eigh ( a , eigen_mode = "V" )
852+ sgn = dpnp .sign (s , out = s )
853+ s = dpnp .abs ( s , out = s )
825854 sidx = dpnp .argsort (s )[..., ::- 1 ]
826855 # Rearrange the signs according to sorted indices
827856 sgn = dpnp .take_along_axis (sgn , sidx , axis = - 1 )
@@ -832,11 +861,10 @@ def _hermitian_svd(a, compute_uv):
832861 # Singular values are unsigned, move the sign into v
833862 # Compute V^T adjusting for the sign and conjugating
834863 vt = dpnp .transpose (u * sgn [..., None , :]).conjugate ()
835- return u , s , vt
864+ return SVDResult ( u , s , vt )
836865
837- # TODO: use dpnp.linalg.eighvals when it is updated
838- s , _ = dpnp .linalg .eigh (a )
839- s = dpnp .abs (s )
866+ s = dpnp_eigh (a , eigen_mode = "N" )
867+ s = dpnp .abs (s , out = s )
840868 return dpnp .sort (s )[..., ::- 1 ]
841869
842870
@@ -1423,7 +1451,7 @@ def _zero_batched_qr(a, mode, m, n, k, res_type):
14231451 batch_shape = a .shape [:- 2 ]
14241452
14251453 if mode == "reduced" :
1426- return (
1454+ return QRResult (
14271455 dpnp .empty_like (
14281456 a ,
14291457 shape = batch_shape + (m , k ),
@@ -1443,7 +1471,7 @@ def _zero_batched_qr(a, mode, m, n, k, res_type):
14431471 usm_type = a_usm_type ,
14441472 sycl_queue = a_sycl_queue ,
14451473 )
1446- return (
1474+ return QRResult (
14471475 q ,
14481476 dpnp .empty_like (
14491477 a ,
@@ -1530,7 +1558,7 @@ def _zero_batched_svd(
15301558 usm_type = usm_type ,
15311559 sycl_queue = exec_q ,
15321560 )
1533- return u , s , vt
1561+ return SVDResult ( u , s , vt )
15341562 return s
15351563
15361564
@@ -1548,22 +1576,28 @@ def _zero_k_qr(a, mode, m, n, res_type):
15481576 m , n = a .shape
15491577
15501578 if mode == "reduced" :
1551- return dpnp .empty_like (
1552- a ,
1553- shape = (m , 0 ),
1554- dtype = res_type ,
1555- ), dpnp .empty_like (
1556- a ,
1557- shape = (0 , n ),
1558- dtype = res_type ,
1579+ return QRResult (
1580+ dpnp .empty_like (
1581+ a ,
1582+ shape = (m , 0 ),
1583+ dtype = res_type ,
1584+ ),
1585+ dpnp .empty_like (
1586+ a ,
1587+ shape = (0 , n ),
1588+ dtype = res_type ,
1589+ ),
15591590 )
15601591 if mode == "complete" :
1561- return dpnp .identity (
1562- m , dtype = res_type , sycl_queue = a_sycl_queue , usm_type = a_usm_type
1563- ), dpnp .empty_like (
1564- a ,
1565- shape = (m , n ),
1566- dtype = res_type ,
1592+ return QRResult (
1593+ dpnp .identity (
1594+ m , dtype = res_type , sycl_queue = a_sycl_queue , usm_type = a_usm_type
1595+ ),
1596+ dpnp .empty_like (
1597+ a ,
1598+ shape = (m , n ),
1599+ dtype = res_type ,
1600+ ),
15671601 )
15681602 if mode == "r" :
15691603 return dpnp .empty_like (
@@ -1648,7 +1682,7 @@ def _zero_m_n_batched_svd(
16481682 usm_type = usm_type ,
16491683 sycl_queue = exec_q ,
16501684 )
1651- return u , s , vt
1685+ return SVDResult ( u , s , vt )
16521686 return s
16531687
16541688
@@ -1692,7 +1726,7 @@ def _zero_m_n_svd(
16921726 usm_type = usm_type ,
16931727 sycl_queue = exec_q ,
16941728 )
1695- return u , s , vt
1729+ return SVDResult ( u , s , vt )
16961730 return s
16971731
16981732
@@ -1993,7 +2027,7 @@ def dpnp_det(a):
19932027 return det .reshape (shape )
19942028
19952029
1996- def dpnp_eigh (a , UPLO , eigen_mode = "V" ):
2030+ def dpnp_eigh (a , UPLO = "L" , eigen_mode = "V" ):
19972031 """
19982032 dpnp_eigh(a, UPLO, eigen_mode="V")
19992033
@@ -2016,7 +2050,7 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
20162050 w = dpnp .empty_like (a , shape = a .shape [:- 1 ], dtype = w_type )
20172051 if eigen_mode == "V" :
20182052 v = dpnp .empty_like (a , dtype = v_type )
2019- return w , v
2053+ return EighResult ( w , v )
20202054 return w
20212055
20222056 if a .ndim > 2 :
@@ -2097,7 +2131,7 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
20972131 else :
20982132 out_v = v
20992133
2100- return (w , out_v ) if eigen_mode == "V" else w
2134+ return EighResult (w , out_v ) if eigen_mode == "V" else w
21012135
21022136
21032137def dpnp_inv (a ):
@@ -2546,7 +2580,7 @@ def dpnp_qr(a, mode="reduced"):
25462580 r = a_t [:, :mc ].transpose ()
25472581
25482582 r = _triu_inplace (r )
2549- return (q , r )
2583+ return QRResult (q , r )
25502584
25512585
25522586def dpnp_solve (a , b ):
@@ -2675,7 +2709,7 @@ def dpnp_slogdet(a):
26752709 usm_type = a_usm_type ,
26762710 sycl_queue = a_sycl_queue ,
26772711 )
2678- return sign , logdet
2712+ return SlogdetResult ( sign , logdet )
26792713
26802714 lu , ipiv , dev_info = _lu_factor (a , res_type )
26812715
@@ -2687,7 +2721,7 @@ def dpnp_slogdet(a):
26872721
26882722 logdet = logdet .astype (logdet_dtype , copy = False )
26892723 singular = dev_info > 0
2690- return (
2724+ return SlogdetResult (
26912725 dpnp .where (singular , res_type .type (0 ), sign ).reshape (shape ),
26922726 dpnp .where (singular , logdet_dtype .type ("-inf" ), logdet ).reshape (shape ),
26932727 )
@@ -2815,10 +2849,10 @@ def dpnp_svd(
28152849 # For A^T = V S^T U^T, `u_h` becomes V and `vt_h` becomes U^T.
28162850 # Transpose and swap them back to restore correct order for A.
28172851 if trans_flag :
2818- return vt_h .T , s_h , u_h .T
2852+ return SVDResult ( vt_h .T , s_h , u_h .T )
28192853 # gesvd call writes `u_h` and `vt_h` in Fortran order;
28202854 # Convert to contiguous to align with NumPy
28212855 u_h = dpnp .ascontiguousarray (u_h )
28222856 vt_h = dpnp .ascontiguousarray (vt_h )
2823- return u_h , s_h , vt_h
2857+ return SVDResult ( u_h , s_h , vt_h )
28242858 return s_h
0 commit comments