Skip to content

Commit 815afcd

Browse files
flabowskindem0
authored andcommitted
explicit arguments
1 parent 6fdca17 commit 815afcd

File tree

1 file changed

+31
-19
lines changed

1 file changed

+31
-19
lines changed

ezyrb/reduction/pod.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111

1212

1313
class POD(Reduction):
14-
def __init__(self, method="svd", **kwargs):
14+
def __init__(
15+
self,
16+
method="svd",
17+
rank=-1,
18+
subspace_iteration=1,
19+
omega_rank=0,
20+
save_memory=False,
21+
):
1522
"""
1623
Perform the Proper Orthogonal Decomposition.
1724
@@ -50,27 +57,23 @@ def __init__(self, method="svd", **kwargs):
5057
omega_rank=10)
5158
>>> pod = POD('correlation_matrix', rank=10, save_memory=False)
5259
"""
53-
available_methods = {
54-
"svd": (self._svd, {"rank": -1}),
55-
"randomized_svd": (self._rsvd, {"rank": -1, "subspace_iteration": 1, "omega_rank": 0}),
56-
"correlation_matrix": (self._corrm, {"rank": -1, "save_memory": False}),
57-
}
60+
self.available_methods = ["svd", "randomized_svd", "correlation_matrix"]
61+
self.rank = rank
62+
if method == "svd":
63+
self._method = self._svd
64+
elif method == "randomized_svd":
65+
self.subspace_iteration = subspace_iteration
66+
self.omega_rank = omega_rank
67+
self._method = self._rsvd
68+
elif method == "correlation_matrix":
69+
self.save_memory = save_memory
70+
self._method = self._corrm
71+
else:
72+
self._method = None
5873

5974
self._modes = None
6075
self._singular_values = None
6176

62-
method = available_methods.get(method)
63-
if method is None:
64-
raise RuntimeError(
65-
f"Invalid method for POD. Please chose one among {', '.join(available_methods)}"
66-
)
67-
68-
self._method, args = method
69-
args.update(kwargs)
70-
71-
for hyperparam, value in args.items():
72-
setattr(self, hyperparam, value)
73-
7477
@property
7578
def modes(self):
7679
"""
@@ -96,6 +99,11 @@ def fit(self, X):
9699
97100
:param numpy.ndarray X: the input snapshots matrix (stored by column)
98101
"""
102+
if self._method is None:
103+
m = self.available_methods
104+
raise RuntimeError(
105+
f"Invalid method for POD. Please chose one among {', '.join(m)}"
106+
)
99107
self._modes, self._singular_values = self._method(X)
100108
return self
101109

@@ -201,7 +209,11 @@ def _rsvd(self, X):
201209
constructing approximate matrix decompositions. N. Halko, P. G.
202210
Martinsson, J. A. Tropp.
203211
"""
204-
if self.omega_rank == 0 and isinstance(self.rank, int) and self.rank not in [0, -1]:
212+
if (
213+
self.omega_rank == 0
214+
and isinstance(self.rank, int)
215+
and self.rank not in [0, -1]
216+
):
205217
omega_rank = self.rank * 2
206218
elif self.omega_rank == 0:
207219
omega_rank = X.shape[1] * 2

0 commit comments

Comments
 (0)