77from numpy .lib .stride_tricks import sliding_window_view
88
99
10- def __svht (sigma_svd : np .ndarray , rows : int , cols : int ) -> int :
10+ def _svht (sigma_svd : np .ndarray , rows : int , cols : int ) -> int :
1111 """
1212 Singular Value Hard Threshold.
1313
@@ -27,8 +27,7 @@ def __svht(sigma_svd: np.ndarray, rows: int, cols: int) -> int:
2727 https://ieeexplore.ieee.org/document/6846297
2828 """
2929 beta = np .divide (* sorted ((rows , cols )))
30- beta_square = beta * beta
31- omega = 0.56 * beta_square * beta - 0.95 * beta_square + 1.82 * beta + 1.43
30+ omega = 0.56 * beta ** 3 - 0.95 * beta ** 2 + 1.82 * beta + 1.43
3231 tau = np .median (sigma_svd ) * omega
3332 rank = np .sum (sigma_svd > tau )
3433
@@ -43,12 +42,13 @@ def __svht(sigma_svd: np.ndarray, rows: int, cols: int) -> int:
4342 return rank
4443
4544
46- def __compute_rank (
45+ def _compute_rank (
4746 sigma_svd : np .ndarray , rows : int , cols : int , svd_rank : Union [float , int ]
4847) -> int :
4948 """
5049 Rank computation for the truncated Singular Value Decomposition.
51- :param numpy.ndarray X: the matrix to decompose.
50+ :param sigma_svd: 1D singular values of SVD.
51+ :type sigma_svd: np.ndarray
5252 :param svd_rank: the rank for the truncation; If 0, the method computes
5353 the optimal rank and uses it for truncation; if positive interger,
5454 the method uses the argument for the truncation; if float between 0
@@ -64,18 +64,12 @@ def __compute_rank(
6464 (2014): 5040-5053.
6565 """
6666 if svd_rank == 0 :
67- rank = __svht (sigma_svd , rows , cols )
68-
67+ rank = _svht (sigma_svd , rows , cols )
6968 elif 0 < svd_rank < 1 :
70- sigma_svd_squared = np .square (sigma_svd )
71- cumulative_energy = np .cumsum (
72- sigma_svd_squared / sigma_svd_squared .sum ()
73- )
69+ cumulative_energy = np .cumsum (sigma_svd ** 2 / (sigma_svd ** 2 ).sum ())
7470 rank = np .searchsorted (cumulative_energy , svd_rank ) + 1
75-
7671 elif svd_rank >= 1 and isinstance (svd_rank , int ):
7772 rank = min (svd_rank , sigma_svd .size )
78-
7973 else :
8074 rank = min (rows , cols )
8175
@@ -101,7 +95,7 @@ def compute_rank(X: np.ndarray, svd_rank=0):
10195 (2014): 5040-5053.
10296 """
10397 _ , s , _ = np .linalg .svd (X , full_matrices = False )
104- return __compute_rank (s , X .shape [0 ], X .shape [1 ], svd_rank )
98+ return _compute_rank (s , X .shape [0 ], X .shape [1 ], svd_rank )
10599
106100
107101def compute_tlsq (X , Y , tlsq_rank ):
@@ -154,7 +148,7 @@ def compute_svd(X, svd_rank=0):
154148 (2014): 5040-5053.
155149 """
156150 U , s , V = np .linalg .svd (X , full_matrices = False )
157- rank = __compute_rank (s , X .shape [0 ], X .shape [1 ], svd_rank )
151+ rank = _compute_rank (s , X .shape [0 ], X .shape [1 ], svd_rank )
158152 V = V .conj ().T
159153
160154 U = U [:, :rank ]
0 commit comments