Skip to content

Commit 4d896cc

Browse files
klapomtezzele
authored and committed
Clustering method as callable.
- Default is MiniBatchKMeans in both mrCOSTS/COSTS - Adopted review suggestions - Added docstring - Cached class indexing for reconstruction.
1 parent 8be0bdf commit 4d896cc

File tree

4 files changed

+69
-88
lines changed

4 files changed

+69
-88
lines changed

pydmd/costs.py

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pydmd.bopdmd import BOPDMD
33
from .utils import compute_rank, compute_svd
44
import copy
5-
from sklearn.cluster import KMeans
5+
from sklearn.cluster import MiniBatchKMeans
66
from sklearn.metrics import silhouette_score
77
import matplotlib.pyplot as plt
88
import xarray as xr
@@ -258,7 +258,6 @@ def build_windows(data, window_length, step_size, integer_windows=False):
258258
:type integer_windows: bool
259259
:return:
260260
"""
261-
262261
if integer_windows:
263262
n_split = np.floor(data.shape[1] / window_length).astype(int)
264263
else:
@@ -329,11 +328,9 @@ def _build_initizialization(self):
329328
:return: First guess of eigenvalues
330329
:rtype: numpy.ndarray or None
331330
"""
332-
# If not initial values are provided return None by default.
333-
init_alpha = None
334331
# User provided initial eigenvalues.
335332
if self._initialize_artificially and self._init_alpha is not None:
336-
init_alpha = self._init_alpha
333+
return self._init_alpha
337334
# Initial eigenvalue guesses from kmeans clustering.
338335
elif (
339336
self._initialize_artificially
@@ -347,6 +344,7 @@ def _build_initizialization(self):
347344
init_alpha = init_alpha * np.tile(
348345
[1, -1], int(self._svd_rank / self._n_components)
349346
)
347+
return init_alpha
350348
# The user accidentally provided both methods of initializing the eigenvalues.
351349
elif (
352350
self._initialize_artificially
@@ -356,8 +354,9 @@ def _build_initizialization(self):
356354
raise ValueError(
357355
"Only one of `init_alpha` and `cluster_centroids` can be provided"
358356
)
359-
360-
return init_alpha
357+
# If no initial values are provided, return None by default.
358+
else:
359+
return None
361360

362361
def fit(
363362
self,
@@ -516,7 +515,9 @@ def fit(
516515
self._amplitudes_array[
517516
k, : optdmd.eigs.shape[0]
518517
] = optdmd.amplitudes
519-
self._window_means_array[k] = c.flatten()
518+
self._window_means_array[k] = np.mean(
519+
data_window, 1, keepdims=True
520+
).flatten()
520521
self._time_array[k] = original_time_window
521522

522523
# Reset optdmd between iterations
@@ -549,16 +550,14 @@ def get_window_indices(self, k):
549550
if k == self._n_slides - 1 and self._non_integer_n_slide:
550551
return slice(-self._window_length, None)
551552
else:
552-
return slice(
553-
sample_start, sample_start + self._window_length
554-
)
553+
return slice(sample_start, sample_start + self._window_length)
555554

556555
def cluster_omega(
557556
self,
558557
n_components,
559558
kmeans_kwargs=None,
560559
transform_method=None,
561-
method="KMeans",
560+
method=MiniBatchKMeans,
562561
):
563562
"""Clusters fitted eigenvalues into frequency bands by the imaginary component.
564563
@@ -567,35 +566,33 @@ def cluster_omega(
567566
:type method: callable
568567
:param n_components: Hyperparameter for k-means clustering, number of clusters.
569568
:type n_components: int
570-
:param kmeans_kwargs: Arguments for KMeans clustering.
569+
:param kmeans_kwargs: Arguments for KMeans clustering. The default is
570+
random_state = 0.
571571
:type kmeans_kwargs: dict
572572
:param transform_method: How to transform omega. See docstring for valid options.
573573
:type transform_method: str or NoneType
574574
:return:
575575
"""
576576
# Reshape the omega array into a 1d array
577-
omega_array = self.omega_array
578-
n_slides = omega_array.shape[0]
579-
svd_rank = omega_array.shape[1]
580-
omega_rshp = omega_array.reshape(n_slides * svd_rank)
577+
n_slides = self.omega_array.shape[0]
578+
svd_rank = self.omega_array.shape[1]
579+
omega_rshp = self.omega_array.reshape(n_slides * svd_rank)
581580
omega_transform = self.transform_omega(
582581
omega_rshp, transform_method=transform_method
583582
)
584583

585584
if kmeans_kwargs is None:
585+
kmeans_kwargs = {}
586586
random_state = 0
587-
kmeans_kwargs = {
588-
"random_state": random_state,
589-
}
590-
if method == "KMeans":
591-
clustering = KMeans(n_clusters=n_components, **kmeans_kwargs)
592-
elif method == "KMediods":
593-
from sklearn_extra.cluster import KMedoids
594-
595-
clustering = KMedoids(n_clusters=n_components, **kmeans_kwargs)
596-
else:
587+
kmeans_kwargs["random_state"] = kmeans_kwargs.get(
588+
"random_state", random_state
589+
)
590+
clustering = method(n_clusters=n_components, **kmeans_kwargs)
591+
if not hasattr(clustering, "fit_predict") and callable(
592+
getattr(clustering, "fit_predict")
593+
):
597594
raise ValueError(
598-
"Unrecognized clustering method {}.".format(method)
595+
"Clustering method must have `fit_predict()` method."
599596
)
600597

601598
omega_classes = clustering.fit_predict(np.atleast_2d(omega_transform).T)
@@ -615,9 +612,7 @@ def cluster_omega(
615612
self._transform_method = transform_method
616613
self._n_components = n_components
617614

618-
return self
619-
620-
def transform_omega(self, omega_array, transform_method=None):
615+
def transform_omega(self, omega_array, transform_method="absolute"):
621616
"""Transform omega, primarily for clustering.
622617
Options for transforming omega are:
623618
"period": :math:`\\frac{1}{\\omega}`
@@ -633,18 +628,14 @@ def transform_omega(self, omega_array, transform_method=None):
633628
:rtype: numpy.ndarray
634629
"""
635630
# Apply a transformation to omega to improve frequency band separation
636-
if transform_method is None or transform_method == "absolute":
631+
if transform_method == "absolute":
637632
omega_transform = np.abs(omega_array.imag.astype("float"))
638633
self._omega_label = r"$|\omega|$"
639634
self._hist_kwargs = {"bins": 64}
640635
# Outstanding question: should this be the complex conjugate or
641636
# the imaginary component squared?
642637
elif transform_method == "square_frequencies":
643-
# omega_transform = (np.conj(omega_array) * omega_array).real.astype(
644-
# "float"
645-
# )
646638
omega_transform = (omega_array.imag**2).real.astype("float")
647-
648639
self._omega_label = r"$|\omega|^{2}$"
649640
self._hist_kwargs = {"bins": 64}
650641
elif transform_method == "log10":
@@ -654,17 +645,9 @@ def transform_omega(self, omega_array, transform_method=None):
654645
omega_transform[~np.isfinite(omega_transform)] = zero_imputer
655646
self._omega_label = r"$log_{10}(|\omega|)$"
656647
self._hist_kwargs = {"bins": 64}
657-
# {
658-
# "bins": np.linspace(
659-
# np.min(np.log10(omega_transform[omega_transform > 0])),
660-
# np.max(np.log10(omega_transform[omega_transform > 0])),
661-
# 64,
662-
# )
663-
# }
664648
elif transform_method == "period":
665649
omega_transform = 1 / np.abs(omega_array.imag.astype("float"))
666650
self._omega_label = "Period"
667-
# @ToDo: Specify bins like in log10 transform
668651
self._hist_kwargs = {"bins": 64}
669652
else:
670653
raise ValueError(
@@ -716,7 +699,7 @@ def cluster_hyperparameter_sweep(
716699
)
717700

718701
for nind, n in enumerate(n_components_range):
719-
_ = self.cluster_omega(
702+
self.cluster_omega(
720703
n_components=n, transform_method=transform_method
721704
)
722705

@@ -872,11 +855,12 @@ def scale_reconstruction(
872855
(self._n_components, self._n_data_vars, self._window_length)
873856
)
874857
for j in np.unique(self._omega_classes):
858+
class_index = classification == j
875859
xr_sep_window[j] = np.linalg.multi_dot(
876860
[
877-
w[:, classification == j],
878-
np.diag(b[classification == j]),
879-
np.exp(omega[classification == j] * t),
861+
w[:, class_index],
862+
np.diag(b[class_index]),
863+
np.exp(omega[class_index] * t),
880864
]
881865
).real
882866

@@ -1222,7 +1206,6 @@ def plot_time_series(
12221206
scale_reconstruction_kwargs = {}
12231207
xr_sep = self.scale_reconstruction(**scale_reconstruction_kwargs)
12241208

1225-
# ToDo: Make these kwargs adjustable inputs.
12261209
fig, axes = plt.subplots(
12271210
nrows=self.n_components + 2,
12281211
sharex=True,

pydmd/mrcosts.py

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,7 @@ def n_components_global(self):
178178
# @ToDo: Use the class variable instead of passing it around
179179
@property
180180
def omega_classes_interpolated(self):
181-
"""
182-
183-
Note, this returns the multi-resolution interpolation of omega classes.
181+
"""Returns the multi-resolution interpolation of omega classes
184182
185183
:return: Ints for each omega value indicating which cluster it belongs to.
186184
:rtype: list of numpy.ndarray
@@ -191,11 +189,9 @@ def omega_classes_interpolated(self):
191189

192190
@property
193191
def ragged_omega_classes(self):
194-
"""
192+
"""Omega classes for each decomposition level after global clustering.
195193
196-
Note, this returns a list of ragged numpy arrays.
197-
198-
:return: Ints for each omega value indicating which cluster it belongs to.
194+
:return: list of classes for each omega value for each decomposition level.
199195
:rtype: list of numpy.ndarray
200196
"""
201197
if self._omega_classes is None:
@@ -204,7 +200,8 @@ def ragged_omega_classes(self):
204200

205201
@property
206202
def ragged_omega_array(self):
207-
"""
203+
"""Omega values for each decomposition level.
204+
208205
:return: list of omega arrays for each decomposition level.
209206
:rtype: list of numpy.ndarray
210207
"""
@@ -216,7 +213,8 @@ def ragged_omega_array(self):
216213

217214
@property
218215
def ragged_modes_array(self):
219-
"""
216+
"""Modes for each decomposition level.
217+
220218
:return: list of modes arrays for each decomposition level.
221219
:rtype: list of numpy.ndarray
222220
"""
@@ -228,7 +226,8 @@ def ragged_modes_array(self):
228226

229227
@property
230228
def ragged_amplitudes_array(self):
231-
"""
229+
"""Amplitudes for each decomposition level.
230+
232231
:return: list of amplitudes arrays for each decomposition level.
233232
:rtype: list of numpy.ndarray
234233
"""
@@ -240,7 +239,8 @@ def ragged_amplitudes_array(self):
240239

241240
@staticmethod
242241
def _data_shape(data):
243-
"""
242+
"""Give the data shape.
243+
244244
:return: Shape of the data for fitting.
245245
:rtype: Tuple of ints
246246
"""
@@ -695,7 +695,6 @@ def plot_local_time_series(
695695
fig, axes = self.costs_array[level].plot_time_series(
696696
space_index,
697697
x_iter,
698-
# plot_kwargs=plot_kwargs,
699698
scale_reconstruction_kwargs=scale_reconstruction_kwargs,
700699
)
701700

@@ -707,7 +706,8 @@ def global_cluster_hyperparameter_sweep(
707706
transform_method=None,
708707
score_method=None,
709708
verbose=True,
710-
method=None,
709+
method=MiniBatchKMeans,
710+
kmeans_kwargs=None,
711711
):
712712
"""
713713
Hyperparameter search for n_components for kmeans clustering.
@@ -741,11 +741,9 @@ def global_cluster_hyperparameter_sweep(
741741
n_components=n,
742742
transform_method=transform_method,
743743
method=method,
744+
kmeans_kwargs=kmeans_kwargs,
744745
)
745746

746-
if verbose:
747-
print("scoring")
748-
print(np.unique(omega_classes.reshape(-1, 1)))
749747
if score_method is None or score_method == "silhouette":
750748
score[nind] = silhouette_score(
751749
omega.reshape(-1, 1),
@@ -770,7 +768,7 @@ def global_cluster_omega(
770768
n_components=None,
771769
transform_method=None,
772770
kmeans_kwargs=None,
773-
method="KMeans",
771+
method=MiniBatchKMeans,
774772
):
775773
"""Performs frequency band clustering on the global distribution of omega.
776774
@@ -785,6 +783,8 @@ def global_cluster_omega(
785783
Default value is "absolute". All transformations and clustering are performed on
786784
the imaginary portion of omega.
787785
786+
:param method: Clustering method following the sklearn pattern (has `fit_predict`)
787+
and `n_clusters` keyword.
788788
:param n_components: The number of clusters to find.
789789
:type n_components: int
790790
:param transform_method: How to transform omega. See docstring for valid options.
@@ -820,21 +820,17 @@ def global_cluster_omega(
820820
)
821821

822822
if kmeans_kwargs is None:
823+
kmeans_kwargs = {}
823824
random_state = 0
824-
kmeans_kwargs = {
825-
"random_state": random_state,
826-
}
827-
if method == "KMeans":
828-
clustering = MiniBatchKMeans(
829-
n_clusters=n_components, **kmeans_kwargs
825+
kmeans_kwargs["random_state"] = kmeans_kwargs.get(
826+
"random_state", random_state
830827
)
831-
elif method == "KMediods":
832-
from sklearn_extra.cluster import KMedoids
833-
834-
clustering = KMedoids(n_clusters=n_components, **kmeans_kwargs)
835-
else:
828+
clustering = method(n_clusters=n_components, **kmeans_kwargs)
829+
if not hasattr(clustering, "fit_predict") and callable(
830+
getattr(clustering, "fit_predict")
831+
):
836832
raise ValueError(
837-
"Unrecognized clustering method {}.".format(method)
833+
"Clustering method must have `fit_predict()` method."
838834
)
839835

840836
omega_classes = clustering.fit_predict(np.atleast_2d(omega_array).T)
@@ -874,7 +870,6 @@ def transform_omega(omega_array, transform_method=None):
874870
if transform_method is None or transform_method == "absolute":
875871
omega_array = np.abs(omega_array.imag.astype("float"))
876872
elif transform_method == "square_frequencies":
877-
# omega_array = (np.conj(omega_array) * omega_array).astype("float")
878873
omega_array = (omega_array.imag**2).real.astype("float")
879874
elif transform_method == "period":
880875
omega_array = 1 / np.abs(omega_array.imag.astype("float"))
@@ -963,11 +958,12 @@ def global_scale_reconstruction(
963958
)
964959
)
965960
for j in np.arange(0, self._n_components_global):
961+
class_ind = classification == j
966962
xr_sep_window[j, :, :] = np.linalg.multi_dot(
967963
[
968-
w[:, classification == j],
969-
np.diag(b[classification == j]),
970-
np.exp(omega[classification == j] * t),
964+
w[:, class_ind],
965+
np.diag(b[class_ind]),
966+
np.exp(omega[class_ind] * t),
971967
]
972968
)
973969

tests/test_costs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
def overlapping_oscillators():
1010
"""
11-
Given a time vector t_eval = t1, t2, ..., evaluates and returns
12-
the snapshots z(t1), z(t2), ... as columns of the matrix Z.
13-
Simulates data z given by the system of ODEs
14-
z' = Az
15-
where A = [1 -2; 1 -1] and z_0 = [1, 0.1].
11+
Simulates a system with two oscillators with occasionally overlapping
12+
frequencies. This example was adapted from Dylewsky et al., 2019.
13+
14+
Oscillator #1: FitzHugh-Nagumo Model
15+
Oscillator #2: Unforced Duffing Oscillator
1616
"""
1717

1818
def rhs_FNM(t, x, tau, a, b, Iext):

0 commit comments

Comments
 (0)