Skip to content

Commit b099309

Browse files
committed
Improving codestyle, updating documentation.
1 parent ed0912e commit b099309

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

pydmd/utils.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from numpy.lib.stride_tricks import sliding_window_view
88

99

10-
def __svht(sigma_svd: np.ndarray, rows: int, cols: int) -> int:
10+
def _svht(sigma_svd: np.ndarray, rows: int, cols: int) -> int:
1111
"""
1212
Singular Value Hard Threshold.
1313
@@ -27,8 +27,7 @@ def __svht(sigma_svd: np.ndarray, rows: int, cols: int) -> int:
2727
https://ieeexplore.ieee.org/document/6846297
2828
"""
2929
beta = np.divide(*sorted((rows, cols)))
30-
beta_square = beta * beta
31-
omega = 0.56 * beta_square * beta - 0.95 * beta_square + 1.82 * beta + 1.43
30+
omega = 0.56 * beta**3 - 0.95 * beta**2 + 1.82 * beta + 1.43
3231
tau = np.median(sigma_svd) * omega
3332
rank = np.sum(sigma_svd > tau)
3433

@@ -43,12 +42,13 @@ def __svht(sigma_svd: np.ndarray, rows: int, cols: int) -> int:
4342
return rank
4443

4544

46-
def __compute_rank(
45+
def _compute_rank(
4746
sigma_svd: np.ndarray, rows: int, cols: int, svd_rank: Union[float, int]
4847
) -> int:
4948
"""
5049
Rank computation for the truncated Singular Value Decomposition.
51-
:param numpy.ndarray X: the matrix to decompose.
50+
:param sigma_svd: 1D singular values of SVD.
51+
:type sigma_svd: np.ndarray
5252
:param svd_rank: the rank for the truncation; If 0, the method computes
5353
the optimal rank and uses it for truncation; if positive interger,
5454
the method uses the argument for the truncation; if float between 0
@@ -64,18 +64,12 @@ def __compute_rank(
6464
(2014): 5040-5053.
6565
"""
6666
if svd_rank == 0:
67-
rank = __svht(sigma_svd, rows, cols)
68-
67+
rank = _svht(sigma_svd, rows, cols)
6968
elif 0 < svd_rank < 1:
70-
sigma_svd_squared = np.square(sigma_svd)
71-
cumulative_energy = np.cumsum(
72-
sigma_svd_squared / sigma_svd_squared.sum()
73-
)
69+
cumulative_energy = np.cumsum(sigma_svd**2 / (sigma_svd**2).sum())
7470
rank = np.searchsorted(cumulative_energy, svd_rank) + 1
75-
7671
elif svd_rank >= 1 and isinstance(svd_rank, int):
7772
rank = min(svd_rank, sigma_svd.size)
78-
7973
else:
8074
rank = min(rows, cols)
8175

@@ -101,7 +95,7 @@ def compute_rank(X: np.ndarray, svd_rank=0):
10195
(2014): 5040-5053.
10296
"""
10397
_, s, _ = np.linalg.svd(X, full_matrices=False)
104-
return __compute_rank(s, X.shape[0], X.shape[1], svd_rank)
98+
return _compute_rank(s, X.shape[0], X.shape[1], svd_rank)
10599

106100

107101
def compute_tlsq(X, Y, tlsq_rank):
@@ -154,7 +148,7 @@ def compute_svd(X, svd_rank=0):
154148
(2014): 5040-5053.
155149
"""
156150
U, s, V = np.linalg.svd(X, full_matrices=False)
157-
rank = __compute_rank(s, X.shape[0], X.shape[1], svd_rank)
151+
rank = _compute_rank(s, X.shape[0], X.shape[1], svd_rank)
158152
V = V.conj().T
159153

160154
U = U[:, :rank]

0 commit comments

Comments
 (0)