Skip to content

Commit ed0912e

Browse files
committed
Fixing redundant SVD computation
1 parent 065dccf commit ed0912e

File tree

1 file changed

+73
-20
lines changed

1 file changed

+73
-20
lines changed

pydmd/utils.py

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,51 @@
11
"""Utilities module."""
22

33
import warnings
4+
from typing import Union
45

56
import numpy as np
67
from numpy.lib.stride_tricks import sliding_window_view
78

89

9-
def compute_rank(X, svd_rank=0):
10+
def __svht(sigma_svd: np.ndarray, rows: int, cols: int) -> int:
11+
"""
12+
Singular Value Hard Threshold.
13+
14+
:param sigma_svd: Singual values computed by SVD
15+
:type sigma_svd: np.ndarray
16+
:param rows: Number of rows of original data matrix.
17+
:type rows: int
18+
:param cols: Number of columns of original data matrix.
19+
:type cols: int
20+
:return: Computed rank.
21+
:rtype: int
22+
23+
References:
24+
Gavish, Matan, and David L. Donoho, The optimal hard threshold for
25+
singular values is, IEEE Transactions on Information Theory 60.8
26+
(2014): 5040-5053.
27+
https://ieeexplore.ieee.org/document/6846297
28+
"""
29+
beta = np.divide(*sorted((rows, cols)))
30+
beta_square = beta * beta
31+
omega = 0.56 * beta_square * beta - 0.95 * beta_square + 1.82 * beta + 1.43
32+
tau = np.median(sigma_svd) * omega
33+
rank = np.sum(sigma_svd > tau)
34+
35+
if rank == 0:
36+
warnings.warn(
37+
"SVD optimal rank is 0. The largest singular values are "
38+
"indistinguishable from noise. Setting rank truncation to 1.",
39+
RuntimeWarning,
40+
)
41+
rank = 1
42+
43+
return rank
44+
45+
46+
def __compute_rank(
47+
sigma_svd: np.ndarray, rows: int, cols: int, svd_rank: Union[float, int]
48+
) -> int:
1049
"""
1150
Rank computation for the truncated Singular Value Decomposition.
1251
:param numpy.ndarray X: the matrix to decompose.
@@ -24,33 +63,47 @@ def compute_rank(X, svd_rank=0):
2463
singular values is, IEEE Transactions on Information Theory 60.8
2564
(2014): 5040-5053.
2665
"""
27-
U, s, _ = np.linalg.svd(X, full_matrices=False)
28-
29-
def omega(x):
30-
return 0.56 * x**3 - 0.95 * x**2 + 1.82 * x + 1.43
31-
3266
if svd_rank == 0:
33-
beta = np.divide(*sorted(X.shape))
34-
tau = np.median(s) * omega(beta)
35-
rank = np.sum(s > tau)
36-
if rank == 0:
37-
warnings.warn(
38-
"SVD optimal rank is 0. The largest singular values are "
39-
"indistinguishable from noise. Setting rank truncation to 1.",
40-
RuntimeWarning,
41-
)
42-
rank = 1
67+
rank = __svht(sigma_svd, rows, cols)
68+
4369
elif 0 < svd_rank < 1:
44-
cumulative_energy = np.cumsum(s**2 / (s**2).sum())
70+
sigma_svd_squared = np.square(sigma_svd)
71+
cumulative_energy = np.cumsum(
72+
sigma_svd_squared / sigma_svd_squared.sum()
73+
)
4574
rank = np.searchsorted(cumulative_energy, svd_rank) + 1
75+
4676
elif svd_rank >= 1 and isinstance(svd_rank, int):
47-
rank = min(svd_rank, U.shape[1])
77+
rank = min(svd_rank, sigma_svd.size)
78+
4879
else:
49-
rank = min(X.shape)
80+
rank = min(rows, cols)
5081

5182
return rank
5283

5384

85+
def compute_rank(X: np.ndarray, svd_rank=0):
86+
"""
87+
Rank computation for the truncated Singular Value Decomposition.
88+
:param numpy.ndarray X: the matrix to decompose.
89+
:param svd_rank: the rank for the truncation; If 0, the method computes
90+
the optimal rank and uses it for truncation; if positive interger,
91+
the method uses the argument for the truncation; if float between 0
92+
and 1, the rank is the number of the biggest singular values that
93+
are needed to reach the 'energy' specified by `svd_rank`; if -1,
94+
the method does not compute truncation. Default is 0.
95+
:type svd_rank: int or float
96+
:return: the computed rank truncation.
97+
:rtype: int
98+
References:
99+
Gavish, Matan, and David L. Donoho, The optimal hard threshold for
100+
singular values is, IEEE Transactions on Information Theory 60.8
101+
(2014): 5040-5053.
102+
"""
103+
_, s, _ = np.linalg.svd(X, full_matrices=False)
104+
return __compute_rank(s, X.shape[0], X.shape[1], svd_rank)
105+
106+
54107
def compute_tlsq(X, Y, tlsq_rank):
55108
"""
56109
Compute Total Least Square.
@@ -100,8 +153,8 @@ def compute_svd(X, svd_rank=0):
100153
singular values is, IEEE Transactions on Information Theory 60.8
101154
(2014): 5040-5053.
102155
"""
103-
rank = compute_rank(X, svd_rank)
104156
U, s, V = np.linalg.svd(X, full_matrices=False)
157+
rank = __compute_rank(s, X.shape[0], X.shape[1], svd_rank)
105158
V = V.conj().T
106159

107160
U = U[:, :rank]

0 commit comments

Comments
 (0)