2222class COSTS :
2323 """Coherent Spatio-Temporal Scale Separation with DMD.
2424
25- :param window_length: Length of the analysis window in number of time steps.
26- :type window_length: int
27- :param step_size: Number of time steps to slide each CSM-DMD window.
28- :type step_size: int
2925 :param n_components: Number of independent frequency bands for this
3026 window length.
3127 :type n_components: int
@@ -60,10 +56,6 @@ class COSTS:
6056 :param max_rank: Maximum svd_rank allowed when the svd_rank is found
6157 through rank truncation (i.e., svd_rank=0).
6258 :type max_rank: int
63- :param use_kmean_freqs: Flag specifying if the BOPDMD fit should use
64- initial values taken from cluster centroids, e.g., from a previoius
65- iteration.
66- :type use_kmean_freqs: bool
6759 :param init_alpha: Initial guess for the eigenvalues provided to BOPDMD.
6860 Must be equal to the `svd_rank`.
6961 :type init_alpha: numpy array
@@ -286,7 +278,7 @@ def relative_error(x_est, x_true):
286278 return np .linalg .norm (x_est - x_true ) / np .linalg .norm (x_true )
287279
288280 @staticmethod
289- def build_windows (data , window_length , step_size , integer_windows = False ):
281+ def _build_windows (data , window_length , step_size , integer_windows = False ):
290282 """How many times integer slides fit the data for a given step and
291283 window size.
292284
@@ -396,7 +388,7 @@ def _build_proj_basis(self, data, svd_rank=None):
396388 # Recover the first r modes of the global svd
397389 return compute_svd (data , svd_rank = svd_rank )[0 ]
398390
399- def _build_initizialization (self ):
391+ def _build_initialization (self ):
400392 """Method for making initial guess of DMD eigenvalues.
401393
402394 :return: First guess of eigenvalues
@@ -421,7 +413,7 @@ def _build_initizialization(self):
421413 return init_alpha
422414 # The user accidentally provided both methods of initializing the
423415 # eigenvalues.
424- elif (
416+ if (
425417 self ._initialize_artificially
426418 and self ._init_alpha is not None
427419 and self ._cluster_centroids is not None
@@ -430,9 +422,10 @@ def _build_initizialization(self):
430422 "Only one of `init_alpha` and `cluster_centroids` can be"
431423 " provided"
432424 )
433- # If not initial values are provided return None by default.
434- else :
435- return None
425+
426+ # In all other cases we return None and let the first iteration of
427+ # BOPDMD searches for the initial values.
428+ return None
436429
437430 def fit (
438431 self ,
@@ -450,9 +443,11 @@ def fit(
450443 :type data: numpy.ndarray
451444 :param time: time series labeling the 1D snapshots
452445 :type time: numpy.ndarray
453- :param window_length: decomposition window length
446+ :param window_length: decomposition window length in number of time
447+ steps.
454448 :type window_length: int
455- :param step_size: how far to slide each window from the previous window.
449+ :param step_size: Number of time steps to slide forward from the
450+ previous window.
456451 :type step_size: int
457452 :param verbose: notifies progress for fitting. Default is False.
458453 :type verbose: bool
@@ -464,7 +459,7 @@ def fit(
464459 self ._window_length = window_length
465460 self ._step_size = step_size
466461 self ._n_time_steps , self ._n_data_vars = self ._data_shape (data )
467- self ._n_slides = self .build_windows (
462+ self ._n_slides = self ._build_windows (
468463 data ,
469464 self ._window_length ,
470465 self ._step_size ,
@@ -473,8 +468,9 @@ def fit(
473468 if self ._window_length > self ._n_time_steps :
474469 raise ValueError (
475470 (
476- "Window length ({}) is larger than the time dimension ({})"
477- ).format (self ._window_length , self ._n_time_steps )
471+ f"Window length ({ self ._window_length } ) is larger than the "
472+ f"time dimension ({ self ._n_time_steps } )"
473+ )
478474 )
479475
480476 # If the window size and step size do not span the data in an integer
@@ -535,7 +531,7 @@ def fit(
535531 self ._window_means_array = np .zeros ((self ._n_slides , self ._n_data_vars ))
536532
537533 # Get initial values for the eigenvalues.
538- self ._init_alpha = self ._build_initizialization ()
534+ self ._init_alpha = self ._build_initialization ()
539535
540536 # Initialize the BOPDMD object.
541537 optdmd = BOPDMD (
@@ -554,7 +550,7 @@ def fit(
554550 # Perform the sliding window DMD fitting.
555551 for k in range (self ._n_slides ):
556552 if verbose and k % 50 == 0 :
557- print ("{ } of {}" . format ( k , self ._n_slides ) )
553+ print (f" { k : } of { self ._n_slides : } " )
558554
559555 sample_slice = self .get_window_indices (k )
560556 data_window = data [:, sample_slice ]
@@ -586,7 +582,7 @@ def fit(
586582 optdmd .svd_rank = min (_svd_rank , self ._max_rank )
587583 else :
588584 optdmd .svd_rank = _svd_rank
589- optdmd ._proj_basis = self ._pydmd_kwargs ["proj_basis" ]
585+ optdmd .proj_basis = self ._pydmd_kwargs ["proj_basis" ]
590586
591587 # Fit the window using the optDMD.
592588 optdmd .fit (data_window , time_window )
@@ -681,6 +677,8 @@ def _cluster(
681677 """Clusters fitted eigenvalues into frequency bands by the imaginary
682678 component.
683679
680+ Helper function for clustering. Call `cluster_omega` instead.
681+
684682 :param n_components: Hyperparameter for k-means clustering, number of
685683 clusters.
686684 :type n_components: int
@@ -736,10 +734,7 @@ def _cluster(
736734 omega_classes = lut [omega_classes ]
737735 cluster_centroids = cluster_centroids [idx ]
738736
739- return (
740- cluster_centroids ,
741- omega_classes ,
742- )
737+ return cluster_centroids , omega_classes
743738
744739 def transform_omega (self , omega_array , transform_method = "absolute" ):
745740 """Transform omega, primarily for clustering.
@@ -780,7 +775,7 @@ def transform_omega(self, omega_array, transform_method="absolute"):
780775 self ._hist_kwargs = {"bins" : 64 }
781776 else :
782777 raise ValueError (
783- "Transform method {} not supported." . format ( transform_method )
778+ f "Transform method { transform_method : } not supported."
784779 )
785780
786781 return omega_transform
@@ -790,7 +785,7 @@ def cluster_hyperparameter_sweep(
790785 n_components_range = None ,
791786 transform_method = None ,
792787 method = MiniBatchKMeans ,
793- kneans_kwargs = None ,
788+ clustering_kwargs = None ,
794789 ):
795790 """Hyperparameter search for number of frequency bands.
796791
@@ -815,6 +810,8 @@ def cluster_hyperparameter_sweep(
815810 :param method: Clustering method following the sklearn pattern (has
816811 `fit_predict` and `n_clusters` keywords). Default is
817812 MiniBatchKMeans.
813+ :param clustering_kwargs: keywords to give to the clustering method.
814+ :type clustering_kwargs: dict
818815 :type method: method
819816 :return: optimal value of `n_components` for clustering.
820817 """
@@ -837,14 +834,14 @@ def cluster_hyperparameter_sweep(
837834 )
838835
839836 for nind , n in enumerate (n_components_range ):
840- self ._cluster (
837+ _ , omega_classes = self ._cluster (
841838 n_components = n ,
842839 transform_method = transform_method ,
843- kmeans_kwargs = kneans_kwargs ,
840+ kmeans_kwargs = clustering_kwargs ,
844841 method = method ,
845842 )
846843
847- classes_reshape = self . omega_classes .reshape (
844+ classes_reshape = omega_classes .reshape (
848845 self ._n_slides * self ._svd_rank_pre_allocate
849846 )
850847
@@ -1121,8 +1118,9 @@ def plot_scale_separation(
11211118 if plot_contours :
11221119 ax .contour (data , colors = ["k" ])
11231120 ax .set_title (
1124- "Input Data at decomposition window length = {}" .format (
1125- self ._window_length
1121+ (
1122+ f"Input data at decomposition window length ="
1123+ f" { self ._window_length :} "
11261124 )
11271125 )
11281126 ax .set_ylabel ("Space (-)" )
@@ -1238,9 +1236,7 @@ def plot_reconstructions(
12381236 ax .set_ylabel ("Space (-)" )
12391237 ax .set_xlabel ("Time (-)" )
12401238 ax .set_title (
1241- "Input Data at decomposition window length = {}" .format (
1242- self ._window_length
1243- )
1239+ f"Input Data at decomposition window length = { self ._window_length } "
12441240 )
12451241 for n_cluster , cluster in enumerate (self ._cluster_centroids ):
12461242 if plot_period :
@@ -1321,7 +1317,7 @@ def plot_error(
13211317 ax_glbl_r .set_xlabel ("time (-)" )
13221318 ax_glbl_r .set_ylabel ("space (-)" )
13231319 re = self .relative_error (global_reconstruction .real , data )
1324- ax_glbl_r .set_title ("Error in Global Reconstruction = {:.2}" . format ( re ) )
1320+ ax_glbl_r .set_title (f "Error in Global Reconstruction = { re :.2} " )
13251321
13261322 def plot_time_series (
13271323 self ,
@@ -1373,9 +1369,8 @@ def plot_time_series(
13731369 lw = 0.5 ,
13741370 )
13751371 ax .set_title (
1376- "window={}, black=input data, red=reconstruction" .format (
1377- self ._window_length
1378- )
1372+ f"window={ self ._window_length } , black=input data, "
1373+ f"red=reconstruction"
13791374 )
13801375 ax .set_ylabel ("Amp." )
13811376 ax .set_xlabel ("Time" )
@@ -1395,9 +1390,8 @@ def plot_time_series(
13951390 )
13961391 ax .plot (xr_sep [n , space_index , :] - ground_truth_mean )
13971392 else :
1398- title = "Band period = {:.0f} window length" .format (
1399- 2 * np .pi / self .cluster_centroids [n ]
1400- )
1393+ period = 2 * np .pi / self .cluster_centroids [n ]
1394+ title = f"Band period = { period :.0f} window length"
14011395 ax .plot (xr_sep [n , space_index , :])
14021396 ax .set_title (title )
14031397 ax .set_ylabel ("Amp." )
@@ -1505,9 +1499,7 @@ def to_xarray(self):
15051499 )
15061500
15071501 for kw , kw_val in self ._pydmd_kwargs .items ():
1508- ds .attrs ["pydmd_kwargs__{}" .format (kw )] = self ._xarray_sanitize (
1509- kw_val
1510- )
1502+ ds .attrs [f"pydmd_kwargs__{ kw } " ] = self ._xarray_sanitize (kw_val )
15111503
15121504 return ds
15131505
0 commit comments