Skip to content

Commit 4546064

Browse files
flabowskindem0
authored andcommitted
reset changes in parallel module
1 parent 815afcd commit 4546064

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

ezyrb/parallel/pod.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Decomposition, Truncated Randomized Singular Value Decomposition, Truncated
55
Singular Value Decomposition via correlation matrix.
66
"""
7-
87
try:
98
from scipy.linalg import eigh
109
except ImportError:
@@ -15,7 +14,6 @@
1514
from pycompss.api.parameter import INOUT, IN
1615
from .reduction import Reduction
1716

18-
1917
class POD(Reduction):
2018
"""
2119
Perform the Proper Orthogonal Decomposition.
@@ -55,13 +53,21 @@ class POD(Reduction):
5553
omega_rank=10)
5654
>>> pod = POD('correlation_matrix', rank=10, save_memory=False)
5755
"""
58-
59-
def __init__(self, method="svd", **kwargs):
56+
def __init__(self, method='svd', **kwargs):
6057

6158
available_methods = {
62-
"svd": (self._svd, {"rank": -1}),
63-
"randomized_svd": (self._rsvd, {"rank": -1, "subspace_iteration": 1, "omega_rank": 0}),
64-
"correlation_matrix": (self._corrm, {"rank": -1, "save_memory": False}),
59+
'svd': (self._svd, {
60+
'rank': -1
61+
}),
62+
'randomized_svd': (self._rsvd, {
63+
'rank': -1,
64+
'subspace_iteration': 1,
65+
'omega_rank': 0
66+
}),
67+
'correlation_matrix': (self._corrm, {
68+
'rank': -1,
69+
'save_memory': False
70+
}),
6571
}
6672

6773
self._modes = None
@@ -71,11 +77,9 @@ def __init__(self, method="svd", **kwargs):
7177
if method is None:
7278
raise RuntimeError(
7379
"Invalid method for POD. Please chose one among {}".format(
74-
", ".join(available_methods)
75-
)
76-
)
80+
', '.join(available_methods)))
7781

78-
self._method, args = method
82+
self.__method, args = method
7983
args.update(kwargs)
8084

8185
for hyperparam, value in args.items():
@@ -107,7 +111,7 @@ def fit(self, X):
107111
108112
:param numpy.ndarray X: the input snapshots matrix (stored by column)
109113
"""
110-
self._modes, self._singular_values = self._method(X)
114+
self._modes, self._singular_values = self.__method(X)
111115
return self
112116

113117
@task(returns=np.ndarray, target_direction=IN)
@@ -132,7 +136,8 @@ def inverse_transform(self, X, database):
132136
predicted_sol = self.modes.dot(X)
133137

134138
if database and database.scaler_snapshots:
135-
predicted_sol = database.scaler_snapshots.inverse_transform(predicted_sol.T).T
139+
predicted_sol = database.scaler_snapshots.inverse_transform(
140+
predicted_sol.T).T
136141

137142
if 1 in predicted_sol.shape:
138143
predicted_sol = predicted_sol.ravel()
@@ -176,7 +181,6 @@ def _truncation(self, X, s):
176181
:return: the number of modes
177182
:rtype: int
178183
"""
179-
180184
def omega(x):
181185
return 0.56 * x**3 - 0.95 * x**2 + 1.82 * x + 1.43
182186

@@ -227,7 +231,8 @@ def _rsvd(self, X):
227231
constructing approximate matrix decompositions. N. Halko, P. G.
228232
Martinsson, J. A. Tropp.
229233
"""
230-
if self.omega_rank == 0 and isinstance(self.rank, int) and self.rank not in [0, -1]:
234+
if (self.omega_rank == 0 and isinstance(self.rank, int)
235+
and self.rank not in [0, -1]):
231236
omega_rank = self.rank * 2
232237
elif self.omega_rank == 0:
233238
omega_rank = X.shape[1] * 2

0 commit comments

Comments
 (0)