@@ -30,6 +30,7 @@ def __init__(self, rank, scale_coefficients=True):
3030 super ().__init__ ()
3131 self .__scale_coefficients = scale_coefficients
3232 self ._basis = None
33+ self ._singular_values = None
3334 self ._scaler = None
3435 self ._rank = rank
3536
@@ -70,6 +71,19 @@ def basis(self):
7071
7172 return self ._basis [: self .rank ]
7273
74+ @property
75+ def singular_values (self ):
76+ """
77+ The singular values of the POD basis.
78+
79+ :return: The singular values.
80+ :rtype: torch.Tensor
81+ """
82+ if self ._singular_values is None :
83+ return None
84+
85+ return self ._singular_values [: self .rank ]
86+
7387 @property
7488 def scaler (self ):
7589 """
@@ -136,15 +150,19 @@ def _fit_pod(self, X, randomized):
136150 "This may slow down computations." ,
137151 ResourceWarning ,
138152 )
139- self . _basis = torch .svd (X .T )[ 0 ]. T
153+ u , s , v = torch .svd (X .T )
140154 else :
141155 if randomized :
142156 warnings .warn (
143157 "Considering a randomized algorithm to compute the POD basis"
144158 )
145- self ._basis = torch .svd_lowrank (X .T , q = X .shape [0 ])[0 ].T
159+ u , s , v = torch .svd_lowrank (X .T , q = X .shape [0 ])
160+
146161 else :
147- self ._basis = torch .svd (X .T )[0 ].T
162+ u , s , v = torch .svd (X .T )
163+ self ._basis = u .T
164+ self ._singular_values = s
165+
148166
149167 def forward (self , X ):
150168 """
0 commit comments