Skip to content

Commit 25e14bb

Browse files
committed
More codacy compliance and doc improvements
1 parent 3a1486a commit 25e14bb

File tree

3 files changed

+111
-69
lines changed

3 files changed

+111
-69
lines changed

pydmd/costs.py

Lines changed: 68 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import matplotlib.pyplot as plt
1616
import xarray as xr
1717

18-
from .utils import compute_rank, compute_svd
1918
from pydmd.bopdmd import BOPDMD
19+
from .utils import compute_rank, compute_svd
2020

2121

2222
class COSTS:
@@ -123,6 +123,7 @@ def __init__(
123123
self._transform_method = None
124124
self._window_means_array = None
125125
self._non_integer_n_slide = None
126+
self._svd_rank_pre_allocate = None
126127

127128
# Specify default keywords to hand to BOPDMD.
128129
if pydmd_kwargs is None:
@@ -143,110 +144,124 @@ def __init__(
143144

144145
@property
145146
def svd_rank(self):
146-
"""
147+
"""Return the svd_rank used for the BOPDMD fit.
148+
147149
:return: the rank used for the svd truncation.
148150
:rtype: int or float
149151
"""
150152
return self._svd_rank
151153

152154
@property
153155
def global_svd(self):
154-
"""
156+
"""Return if a global svd projection basis was used.
157+
155158
:return: If a global svd was used for the BOP-DMD fit.
156159
:rtype: int or float
157160
"""
158161
return self._global_svd
159162

160163
@property
161164
def window_length(self):
162-
"""
165+
"""Return the window length used.
166+
163167
:return: the length of the windows used for this decomposition level.
164168
:rtype: int or float
165169
"""
166170
return self._window_length
167171

168172
@property
169173
def step_size(self):
170-
"""
174+
"""Return the step size between each window.
175+
171176
:return: the length of the windows used for this decomposition level.
172177
:rtype: int or float
173178
"""
174179
return self._step_size
175180

176181
@property
177182
def n_slides(self):
178-
"""
183+
"""Return the number of slides performed for each window.
184+
179185
:return: number of window slides for this decomposition level.
180186
:rtype: int
181187
"""
182188
return self._n_slides
183189

184190
@property
185191
def modes_array(self):
186-
"""
192+
"""Return the spatial modes of each window's fit.
193+
187194
:return: Modes for each window
188195
:rtype: numpy.ndarray
189196
"""
190197
return self._modes_array
191198

192199
@property
193200
def amplitudes_array(self):
194-
"""
201+
"""Return the amplitudes of each window's fit.
202+
195203
:return: amplitudes of each window
196204
:rtype: numpy.ndarray
197205
"""
198206
return self._amplitudes_array
199207

200208
@property
201209
def omega_array(self):
202-
"""
210+
"""Return the frequencies (omega) of each window's fit.
211+
203212
:return: omega (a.k.a eigenvalues or time dynamics) for each window
204213
:rtype: numpy.ndarray
205214
"""
206215
return self._omega_array
207216

208217
@property
209218
def time_array(self):
210-
"""
219+
"""Return the time values contained by each window.
220+
211221
:return: time values for each fit window
212222
:rtype: numpy.ndarray
213223
"""
214224
return self._time_array
215225

216226
@property
217227
def window_means_array(self):
218-
"""
228+
"""Return the array of window time means.
229+
219230
:return: Time mean of the data in each window
220231
:rtype: numpy.ndarray
221232
"""
222233
return self._window_means_array
223234

224235
@property
225236
def n_components(self):
226-
"""
237+
"""Return the number of frequency bands.
238+
227239
:return: Number of frequency bands fit in the kmeans clustering
228240
:rtype: int
229241
"""
230242
return self._n_components
231243

232244
@property
233245
def cluster_centroids(self):
234-
"""
246+
"""Return the frequency band centroids.
247+
235248
:return: Centroids of the frequency bands
236249
:rtype: numpy.ndarray
237250
"""
238251
return self._cluster_centroids
239252

240253
@property
241254
def omega_classes(self):
242-
"""
255+
"""Return the frequency band classifications.
256+
243257
:return: Frequency band classifications, corresponds to omega_array
244258
:rtype: numpy.ndarray
245259
"""
246260
return self._omega_classes
247261

248262
def periods(self):
249-
"""
263+
"""Convert the omega array into periods.
264+
250265
:return: Time dynamics converted to periods
251266
:rtype: numpy.ndarray
252267
"""
@@ -259,7 +274,15 @@ def periods(self):
259274

260275
@staticmethod
261276
def relative_error(x_est, x_true):
262-
"""Helper function for calculating the relative error."""
277+
"""Helper function for calculating the relative error.
278+
279+
:param x_est: Estimated values (i.e. from reconstruction)
280+
:type x_est: numpy.ndarray
281+
:param x_true: True (or observed) values.
282+
:type x_true: numpy.ndarray
283+
:return: Relative error between observations and model.
284+
:rtype: numpy.ndarray
285+
"""
263286
return np.linalg.norm(x_est - x_true) / np.linalg.norm(x_true)
264287

265288
@staticmethod
@@ -346,13 +369,29 @@ def build_kern(window_length):
346369

347370
@staticmethod
348371
def _data_shape(data):
349-
"""Returns the shape of the data."""
372+
"""Returns the shape of the data.
373+
374+
:param data: Data to fit with mrCOSTS.
375+
:type data: numpy.ndarray
376+
:return n_time_steps: Number of time steps.
377+
:rtype n_time_steps: int
378+
:return n_data_vars: Number of spatial variables.
379+
:rtype n_data_vars: int
380+
"""
350381
n_time_steps = np.shape(data)[1]
351382
n_data_vars = np.shape(data)[0]
352383
return n_time_steps, n_data_vars
353384

354385
def _build_proj_basis(self, data, svd_rank=None):
355-
"""Build the projection basis."""
386+
"""Build the projection basis.
387+
388+
:param data: Data to fit with mrCOSTS.
389+
:type data: numpy.ndarray
390+
:param svd_rank: Rank to fit with COSTS.
391+
:type svd_rank: int
392+
:return: SVD projection basis for COSTS.
393+
:rtype: numpy.ndarray
394+
"""
356395
self._svd_rank = compute_rank(data, svd_rank=svd_rank)
357396
# Recover the first r modes of the global svd
358397
return compute_svd(data, svd_rank=svd_rank)[0]
@@ -388,7 +427,8 @@ def _build_initizialization(self):
388427
and self._cluster_centroids is not None
389428
):
390429
raise ValueError(
391-
"Only one of `init_alpha` and `cluster_centroids` can be provided"
430+
"Only one of `init_alpha` and `cluster_centroids` can be"
431+
" provided"
392432
)
393433
# If not initial values are provided return None by default.
394434
else:
@@ -589,8 +629,7 @@ def get_window_indices(self, k):
589629
sample_start = self._step_size * k
590630
if k == self._n_slides - 1 and self._non_integer_n_slide:
591631
return slice(-self._window_length, None)
592-
else:
593-
return slice(sample_start, sample_start + self._window_length)
632+
return slice(sample_start, sample_start + self._window_length)
594633

595634
def cluster_omega(
596635
self,
@@ -826,10 +865,6 @@ def plot_omega_histogram(self):
826865
:return fig: Figure handle for the plot
827866
:return ax: Axes handle for the plot
828867
"""
829-
# Reshape the omega array into a 1d array
830-
omega_array = self.omega_array
831-
# n_slides = omega_array.shape[0]
832-
# svd_rank = omega_array.shape[1]
833868

834869
# Apply the transformation to omega
835870
omega_transform = self.transform_omega(
@@ -865,12 +900,6 @@ def plot_omega_time_series(self):
865900
fig, ax = plt.subplots(1, 1)
866901
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
867902

868-
# Reshape the omega array into a 1d array
869-
omega_array = self.omega_array
870-
# n_slides = omega_array.shape[0]
871-
# svd_rank = omega_array.shape[1]
872-
# omega_rshp = omega_array.reshape(n_slides * svd_rank)
873-
874903
# Apply the transformation to omega
875904
omega_transform = self.transform_omega(
876905
self.omega_array.flatten(), transform_method=self._transform_method
@@ -1360,7 +1389,10 @@ def plot_time_series(
13601389
for n in range(self.n_components):
13611390
ax = axes[n + 1]
13621391
if n == 0:
1363-
title = "blue = Low-frequency component, black = high frequency residual"
1392+
title = (
1393+
"blue = Low-frequency component, black = high "
1394+
"frequency residual"
1395+
)
13641396
ax.plot(xr_sep[n, space_index, :] - ground_truth_mean)
13651397
else:
13661398
title = "Band period = {:.0f} window length".format(
@@ -1386,7 +1418,10 @@ def plot_time_series(
13861418
label="Residual",
13871419
)
13881420
ax.set_title(
1389-
"black=input data, yellow=low-frequency, blue=high-frequency, red=residual"
1421+
(
1422+
"black=input data, yellow=low-frequency, "
1423+
"blue=high-frequency, red=residual"
1424+
)
13901425
)
13911426
else:
13921427
ax.set_title(

0 commit comments

Comments
 (0)