11"""Utilities module."""
22
33import warnings
4+ from typing import Union
45
56import numpy as np
67from numpy .lib .stride_tricks import sliding_window_view
78
89
9- def compute_rank (X , svd_rank = 0 ):
10+ def __svht (sigma_svd : np .ndarray , rows : int , cols : int ) -> int :
11+ """
12+ Singular Value Hard Threshold.
13+
14+ :param sigma_svd: Singual values computed by SVD
15+ :type sigma_svd: np.ndarray
16+ :param rows: Number of rows of original data matrix.
17+ :type rows: int
18+ :param cols: Number of columns of original data matrix.
19+ :type cols: int
20+ :return: Computed rank.
21+ :rtype: int
22+
23+ References:
24+ Gavish, Matan, and David L. Donoho, The optimal hard threshold for
25+ singular values is, IEEE Transactions on Information Theory 60.8
26+ (2014): 5040-5053.
27+ https://ieeexplore.ieee.org/document/6846297
28+ """
29+ beta = np .divide (* sorted ((rows , cols )))
30+ beta_square = beta * beta
31+ omega = 0.56 * beta_square * beta - 0.95 * beta_square + 1.82 * beta + 1.43
32+ tau = np .median (sigma_svd ) * omega
33+ rank = np .sum (sigma_svd > tau )
34+
35+ if rank == 0 :
36+ warnings .warn (
37+ "SVD optimal rank is 0. The largest singular values are "
38+ "indistinguishable from noise. Setting rank truncation to 1." ,
39+ RuntimeWarning ,
40+ )
41+ rank = 1
42+
43+ return rank
44+
45+
46+ def __compute_rank (
47+ sigma_svd : np .ndarray , rows : int , cols : int , svd_rank : Union [float , int ]
48+ ) -> int :
1049 """
1150 Rank computation for the truncated Singular Value Decomposition.
1251 :param numpy.ndarray X: the matrix to decompose.
@@ -24,33 +63,47 @@ def compute_rank(X, svd_rank=0):
2463 singular values is, IEEE Transactions on Information Theory 60.8
2564 (2014): 5040-5053.
2665 """
27- U , s , _ = np .linalg .svd (X , full_matrices = False )
28-
29- def omega (x ):
30- return 0.56 * x ** 3 - 0.95 * x ** 2 + 1.82 * x + 1.43
31-
3266 if svd_rank == 0 :
33- beta = np .divide (* sorted (X .shape ))
34- tau = np .median (s ) * omega (beta )
35- rank = np .sum (s > tau )
36- if rank == 0 :
37- warnings .warn (
38- "SVD optimal rank is 0. The largest singular values are "
39- "indistinguishable from noise. Setting rank truncation to 1." ,
40- RuntimeWarning ,
41- )
42- rank = 1
67+ rank = __svht (sigma_svd , rows , cols )
68+
4369 elif 0 < svd_rank < 1 :
44- cumulative_energy = np .cumsum (s ** 2 / (s ** 2 ).sum ())
70+ sigma_svd_squared = np .square (sigma_svd )
71+ cumulative_energy = np .cumsum (
72+ sigma_svd_squared / sigma_svd_squared .sum ()
73+ )
4574 rank = np .searchsorted (cumulative_energy , svd_rank ) + 1
75+
4676 elif svd_rank >= 1 and isinstance (svd_rank , int ):
47- rank = min (svd_rank , U .shape [1 ])
77+ rank = min (svd_rank , sigma_svd .size )
78+
4879 else :
49- rank = min (X . shape )
80+ rank = min (rows , cols )
5081
5182 return rank
5283
5384
85+ def compute_rank (X : np .ndarray , svd_rank = 0 ):
86+ """
87+ Rank computation for the truncated Singular Value Decomposition.
88+ :param numpy.ndarray X: the matrix to decompose.
89+ :param svd_rank: the rank for the truncation; If 0, the method computes
90+ the optimal rank and uses it for truncation; if positive interger,
91+ the method uses the argument for the truncation; if float between 0
92+ and 1, the rank is the number of the biggest singular values that
93+ are needed to reach the 'energy' specified by `svd_rank`; if -1,
94+ the method does not compute truncation. Default is 0.
95+ :type svd_rank: int or float
96+ :return: the computed rank truncation.
97+ :rtype: int
98+ References:
99+ Gavish, Matan, and David L. Donoho, The optimal hard threshold for
100+ singular values is, IEEE Transactions on Information Theory 60.8
101+ (2014): 5040-5053.
102+ """
103+ _ , s , _ = np .linalg .svd (X , full_matrices = False )
104+ return __compute_rank (s , X .shape [0 ], X .shape [1 ], svd_rank )
105+
106+
54107def compute_tlsq (X , Y , tlsq_rank ):
55108 """
56109 Compute Total Least Square.
@@ -100,8 +153,8 @@ def compute_svd(X, svd_rank=0):
100153 singular values is, IEEE Transactions on Information Theory 60.8
101154 (2014): 5040-5053.
102155 """
103- rank = compute_rank (X , svd_rank )
104156 U , s , V = np .linalg .svd (X , full_matrices = False )
157+ rank = __compute_rank (s , X .shape [0 ], X .shape [1 ], svd_rank )
105158 V = V .conj ().T
106159
107160 U = U [:, :rank ]
0 commit comments