Skip to content

Commit 8c99d05

Browse files
klapomtezzele
authored andcommitted
mrCOSTS clustering follows COSTS logic
- mrCOSTS clustering now follows the COSTS logic with a dedicated _cluster helper and the global_cluster() and global_cluster_hyperparameter_sweep() calling this function instead. - Most of the plotting routines were given a data prep helper function to cut down on repeated code. - Most of the low-hanging codacy compliance changes. - Updated real data tutorial to use new mrCOSTS default values and updated some text.
1 parent 2159d24 commit 8c99d05

11 files changed

+372
-305
lines changed

pydmd/bopdmd.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,10 @@ def init_alpha(self):
10391039

10401040
@init_alpha.setter
10411041
def init_alpha(self, value):
1042+
"""Set a new initial eigenvalue guess.
1043+
1044+
:param value: The new eigenvalue guess.
1045+
"""
10421046
self._init_alpha = value
10431047

10441048
@property
@@ -1055,6 +1059,14 @@ def proj_basis(self):
10551059
raise RuntimeError(msg)
10561060
return self._proj_basis
10571061

1062+
@proj_basis.setter
1063+
def proj_basis(self, new_proj_basis):
1064+
"""Set a new projection basis.
1065+
1066+
:param new_proj_basis: The new projection basis to assign.
1067+
"""
1068+
self._proj_basis = new_proj_basis
1069+
10581070
@property
10591071
def num_trials(self):
10601072
"""

pydmd/costs.py

Lines changed: 38 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@
2222
class COSTS:
2323
"""Coherent Spatio-Temporal Scale Separation with DMD.
2424
25-
:param window_length: Length of the analysis window in number of time steps.
26-
:type window_length: int
27-
:param step_size: Number of time steps to slide each CSM-DMD window.
28-
:type step_size: int
2925
:param n_components: Number of independent frequency bands for this
3026
window length.
3127
:type n_components: int
@@ -60,10 +56,6 @@ class COSTS:
6056
:param max_rank: Maximum svd_rank allowed when the svd_rank is found
6157
through rank truncation (i.e., svd_rank=0).
6258
:type max_rank: int
63-
:param use_kmean_freqs: Flag specifying if the BOPDMD fit should use
64-
initial values taken from cluster centroids, e.g., from a previoius
65-
iteration.
66-
:type use_kmean_freqs: bool
6759
:param init_alpha: Initial guess for the eigenvalues provided to BOPDMD.
6860
Must be equal to the `svd_rank`.
6961
:type init_alpha: numpy array
@@ -286,7 +278,7 @@ def relative_error(x_est, x_true):
286278
return np.linalg.norm(x_est - x_true) / np.linalg.norm(x_true)
287279

288280
@staticmethod
289-
def build_windows(data, window_length, step_size, integer_windows=False):
281+
def _build_windows(data, window_length, step_size, integer_windows=False):
290282
"""How many times integer slides fit the data for a given step and
291283
window size.
292284
@@ -396,7 +388,7 @@ def _build_proj_basis(self, data, svd_rank=None):
396388
# Recover the first r modes of the global svd
397389
return compute_svd(data, svd_rank=svd_rank)[0]
398390

399-
def _build_initizialization(self):
391+
def _build_initialization(self):
400392
"""Method for making initial guess of DMD eigenvalues.
401393
402394
:return: First guess of eigenvalues
@@ -421,7 +413,7 @@ def _build_initizialization(self):
421413
return init_alpha
422414
# The user accidentally provided both methods of initializing the
423415
# eigenvalues.
424-
elif (
416+
if (
425417
self._initialize_artificially
426418
and self._init_alpha is not None
427419
and self._cluster_centroids is not None
@@ -430,9 +422,10 @@ def _build_initizialization(self):
430422
"Only one of `init_alpha` and `cluster_centroids` can be"
431423
" provided"
432424
)
433-
# If not initial values are provided return None by default.
434-
else:
435-
return None
425+
426+
# In all other cases we return None and let the first iteration of
427+
# BOPDMD searches for the initial values.
428+
return None
436429

437430
def fit(
438431
self,
@@ -450,9 +443,11 @@ def fit(
450443
:type data: numpy.ndarray
451444
:param time: time series labeling the 1D snapshots
452445
:type time: numpy.ndarray
453-
:param window_length: decomposition window length
446+
:param window_length: decomposition window length in number of time
447+
steps.
454448
:type window_length: int
455-
:param step_size: how far to slide each window from the previous window.
449+
:param step_size: Number of time steps to slide forward from the
450+
previous window.
456451
:type step_size: int
457452
:param verbose: notifies progress for fitting. Default is False.
458453
:type verbose: bool
@@ -464,7 +459,7 @@ def fit(
464459
self._window_length = window_length
465460
self._step_size = step_size
466461
self._n_time_steps, self._n_data_vars = self._data_shape(data)
467-
self._n_slides = self.build_windows(
462+
self._n_slides = self._build_windows(
468463
data,
469464
self._window_length,
470465
self._step_size,
@@ -473,8 +468,9 @@ def fit(
473468
if self._window_length > self._n_time_steps:
474469
raise ValueError(
475470
(
476-
"Window length ({}) is larger than the time dimension ({})"
477-
).format(self._window_length, self._n_time_steps)
471+
f"Window length ({self._window_length}) is larger than the "
472+
f"time dimension ({self._n_time_steps})"
473+
)
478474
)
479475

480476
# If the window size and step size do not span the data in an integer
@@ -535,7 +531,7 @@ def fit(
535531
self._window_means_array = np.zeros((self._n_slides, self._n_data_vars))
536532

537533
# Get initial values for the eigenvalues.
538-
self._init_alpha = self._build_initizialization()
534+
self._init_alpha = self._build_initialization()
539535

540536
# Initialize the BOPDMD object.
541537
optdmd = BOPDMD(
@@ -554,7 +550,7 @@ def fit(
554550
# Perform the sliding window DMD fitting.
555551
for k in range(self._n_slides):
556552
if verbose and k % 50 == 0:
557-
print("{} of {}".format(k, self._n_slides))
553+
print(f"{k:} of {self._n_slides:}")
558554

559555
sample_slice = self.get_window_indices(k)
560556
data_window = data[:, sample_slice]
@@ -586,7 +582,7 @@ def fit(
586582
optdmd.svd_rank = min(_svd_rank, self._max_rank)
587583
else:
588584
optdmd.svd_rank = _svd_rank
589-
optdmd._proj_basis = self._pydmd_kwargs["proj_basis"]
585+
optdmd.proj_basis = self._pydmd_kwargs["proj_basis"]
590586

591587
# Fit the window using the optDMD.
592588
optdmd.fit(data_window, time_window)
@@ -681,6 +677,8 @@ def _cluster(
681677
"""Clusters fitted eigenvalues into frequency bands by the imaginary
682678
component.
683679
680+
Helper function for clustering. Call `cluster_omega` instead.
681+
684682
:param n_components: Hyperparameter for k-means clustering, number of
685683
clusters.
686684
:type n_components: int
@@ -736,10 +734,7 @@ def _cluster(
736734
omega_classes = lut[omega_classes]
737735
cluster_centroids = cluster_centroids[idx]
738736

739-
return (
740-
cluster_centroids,
741-
omega_classes,
742-
)
737+
return cluster_centroids, omega_classes
743738

744739
def transform_omega(self, omega_array, transform_method="absolute"):
745740
"""Transform omega, primarily for clustering.
@@ -780,7 +775,7 @@ def transform_omega(self, omega_array, transform_method="absolute"):
780775
self._hist_kwargs = {"bins": 64}
781776
else:
782777
raise ValueError(
783-
"Transform method {} not supported.".format(transform_method)
778+
f"Transform method {transform_method:} not supported."
784779
)
785780

786781
return omega_transform
@@ -790,7 +785,7 @@ def cluster_hyperparameter_sweep(
790785
n_components_range=None,
791786
transform_method=None,
792787
method=MiniBatchKMeans,
793-
kneans_kwargs=None,
788+
clustering_kwargs=None,
794789
):
795790
"""Hyperparameter search for number of frequency bands.
796791
@@ -815,6 +810,8 @@ def cluster_hyperparameter_sweep(
815810
:param method: Clustering method following the sklearn pattern (has
816811
`fit_predict` and `n_clusters` keywords). Default is
817812
MiniBatchKMeans.
813+
:param clustering_kwargs: keywords to give to the clustering method.
814+
:type clustering_kwargs: dict
818815
:type method: method
819816
:return: optimal value of `n_components` for clustering.
820817
"""
@@ -837,14 +834,14 @@ def cluster_hyperparameter_sweep(
837834
)
838835

839836
for nind, n in enumerate(n_components_range):
840-
self._cluster(
837+
_, omega_classes = self._cluster(
841838
n_components=n,
842839
transform_method=transform_method,
843-
kmeans_kwargs=kneans_kwargs,
840+
kmeans_kwargs=clustering_kwargs,
844841
method=method,
845842
)
846843

847-
classes_reshape = self.omega_classes.reshape(
844+
classes_reshape = omega_classes.reshape(
848845
self._n_slides * self._svd_rank_pre_allocate
849846
)
850847

@@ -1121,8 +1118,9 @@ def plot_scale_separation(
11211118
if plot_contours:
11221119
ax.contour(data, colors=["k"])
11231120
ax.set_title(
1124-
"Input Data at decomposition window length = {}".format(
1125-
self._window_length
1121+
(
1122+
f"Input data at decomposition window length ="
1123+
f" {self._window_length:}"
11261124
)
11271125
)
11281126
ax.set_ylabel("Space (-)")
@@ -1238,9 +1236,7 @@ def plot_reconstructions(
12381236
ax.set_ylabel("Space (-)")
12391237
ax.set_xlabel("Time (-)")
12401238
ax.set_title(
1241-
"Input Data at decomposition window length = {}".format(
1242-
self._window_length
1243-
)
1239+
f"Input Data at decomposition window length = {self._window_length}"
12441240
)
12451241
for n_cluster, cluster in enumerate(self._cluster_centroids):
12461242
if plot_period:
@@ -1321,7 +1317,7 @@ def plot_error(
13211317
ax_glbl_r.set_xlabel("time (-)")
13221318
ax_glbl_r.set_ylabel("space (-)")
13231319
re = self.relative_error(global_reconstruction.real, data)
1324-
ax_glbl_r.set_title("Error in Global Reconstruction = {:.2}".format(re))
1320+
ax_glbl_r.set_title(f"Error in Global Reconstruction = {re:.2}")
13251321

13261322
def plot_time_series(
13271323
self,
@@ -1373,9 +1369,8 @@ def plot_time_series(
13731369
lw=0.5,
13741370
)
13751371
ax.set_title(
1376-
"window={}, black=input data, red=reconstruction".format(
1377-
self._window_length
1378-
)
1372+
f"window={self._window_length}, black=input data, "
1373+
f"red=reconstruction"
13791374
)
13801375
ax.set_ylabel("Amp.")
13811376
ax.set_xlabel("Time")
@@ -1395,9 +1390,8 @@ def plot_time_series(
13951390
)
13961391
ax.plot(xr_sep[n, space_index, :] - ground_truth_mean)
13971392
else:
1398-
title = "Band period = {:.0f} window length".format(
1399-
2 * np.pi / self.cluster_centroids[n]
1400-
)
1393+
period = 2 * np.pi / self.cluster_centroids[n]
1394+
title = f"Band period = {period:.0f} window length"
14011395
ax.plot(xr_sep[n, space_index, :])
14021396
ax.set_title(title)
14031397
ax.set_ylabel("Amp.")
@@ -1505,9 +1499,7 @@ def to_xarray(self):
15051499
)
15061500

15071501
for kw, kw_val in self._pydmd_kwargs.items():
1508-
ds.attrs["pydmd_kwargs__{}".format(kw)] = self._xarray_sanitize(
1509-
kw_val
1510-
)
1502+
ds.attrs[f"pydmd_kwargs__{kw}"] = self._xarray_sanitize(kw_val)
15111503

15121504
return ds
15131505

0 commit comments

Comments
 (0)