1515import matplotlib .pyplot as plt
1616import xarray as xr
1717
18- from .utils import compute_rank , compute_svd
1918from pydmd .bopdmd import BOPDMD
19+ from .utils import compute_rank , compute_svd
2020
2121
2222class COSTS :
@@ -123,6 +123,7 @@ def __init__(
123123 self ._transform_method = None
124124 self ._window_means_array = None
125125 self ._non_integer_n_slide = None
126+ self ._svd_rank_pre_allocate = None
126127
127128 # Specify default keywords to hand to BOPDMD.
128129 if pydmd_kwargs is None :
@@ -143,110 +144,124 @@ def __init__(
143144
144145 @property
145146 def svd_rank (self ):
146- """
147+ """Return the svd_rank used for the BOPDMD fit.
148+
147149 :return: the rank used for the svd truncation.
148150 :rtype: int or float
149151 """
150152 return self ._svd_rank
151153
152154 @property
153155 def global_svd (self ):
154- """
156+ """Return if a global svd projection basis was used.
157+
155158 :return: If a global svd was used for the BOP-DMD fit.
156159 :rtype: int or float
157160 """
158161 return self ._global_svd
159162
160163 @property
161164 def window_length (self ):
162- """
165+ """Return the window length used.
166+
163167 :return: the length of the windows used for this decomposition level.
164168 :rtype: int or float
165169 """
166170 return self ._window_length
167171
168172 @property
169173 def step_size (self ):
170- """
174+ """Return the step size between each window.
175+
171176 :return: the length of the windows used for this decomposition level.
172177 :rtype: int or float
173178 """
174179 return self ._step_size
175180
176181 @property
177182 def n_slides (self ):
178- """
183+ """Return the number of slides performed for each window.
184+
179185 :return: number of window slides for this decomposition level.
180186 :rtype: int
181187 """
182188 return self ._n_slides
183189
184190 @property
185191 def modes_array (self ):
186- """
192+ """Return the spatial modes of each window's fit.
193+
187194 :return: Modes for each window
188195 :rtype: numpy.ndarray
189196 """
190197 return self ._modes_array
191198
192199 @property
193200 def amplitudes_array (self ):
194- """
201+ """Return the amplitudes of each window's fit.
202+
195203 :return: amplitudes of each window
196204 :rtype: numpy.ndarray
197205 """
198206 return self ._amplitudes_array
199207
200208 @property
201209 def omega_array (self ):
202- """
210+ """Return the frequencies (omega) of each window's fit.
211+
203212 :return: omega (a.k.a eigenvalues or time dynamics) for each window
204213 :rtype: numpy.ndarray
205214 """
206215 return self ._omega_array
207216
208217 @property
209218 def time_array (self ):
210- """
219+ """Return the time values contained by each window.
220+
211221 :return: time values for each fit window
212222 :rtype: numpy.ndarray
213223 """
214224 return self ._time_array
215225
216226 @property
217227 def window_means_array (self ):
218- """
228+ """Return the array of window time means.
229+
219230 :return: Time mean of the data in each window
220231 :rtype: numpy.ndarray
221232 """
222233 return self ._window_means_array
223234
224235 @property
225236 def n_components (self ):
226- """
237+ """Return the number of frequency bands.
238+
227239 :return: Number of frequency bands fit in the kmeans clustering
228240 :rtype: int
229241 """
230242 return self ._n_components
231243
232244 @property
233245 def cluster_centroids (self ):
234- """
246+ """Return the frequency band centroids.
247+
235248 :return: Centroids of the frequency bands
236249 :rtype: numpy.ndarray
237250 """
238251 return self ._cluster_centroids
239252
240253 @property
241254 def omega_classes (self ):
242- """
255+ """Return the frequency band classifications.
256+
243257 :return: Frequency band classifications, corresponds to omega_array
244258 :rtype: numpy.ndarray
245259 """
246260 return self ._omega_classes
247261
248262 def periods (self ):
249- """
263+ """Convert the omega array into periods.
264+
250265 :return: Time dynamics converted to periods
251266 :rtype: numpy.ndarray
252267 """
@@ -259,7 +274,15 @@ def periods(self):
259274
260275 @staticmethod
261276 def relative_error (x_est , x_true ):
262- """Helper function for calculating the relative error."""
277+ """Helper function for calculating the relative error.
278+
279+ :param x_est: Estimated values (i.e. from reconstruction)
280+ :type x_est: numpy.ndarray
281+ :param x_true: True (or observed) values.
282+ :type x_true: numpy.ndarray
283+ :return: Relative error between observations and model.
284+ :rtype: numpy.ndarray
285+ """
263286 return np .linalg .norm (x_est - x_true ) / np .linalg .norm (x_true )
264287
265288 @staticmethod
@@ -346,13 +369,29 @@ def build_kern(window_length):
346369
347370 @staticmethod
348371 def _data_shape (data ):
349- """Returns the shape of the data."""
372+ """Returns the shape of the data.
373+
374+ :param data: Data to fit with mrCOSTS.
375+ :type data: numpy.ndarray
376+ :return n_time_steps: Number of time steps.
377+ :rtype n_time_steps: int
378+ :return n_data_vars: Number of spatial variables.
379+ :rtype n_data_vars: int
380+ """
350381 n_time_steps = np .shape (data )[1 ]
351382 n_data_vars = np .shape (data )[0 ]
352383 return n_time_steps , n_data_vars
353384
354385 def _build_proj_basis (self , data , svd_rank = None ):
355- """Build the projection basis."""
386+ """Build the projection basis.
387+
388+ :param data: Data to fit with mrCOSTS.
389+ :type data: numpy.ndarray
390+ :param svd_rank: Rank to fit with COSTS.
391+ :type svd_rank: int
392+ :return: SVD projection basis for COSTS.
393+ :rtype: numpy.ndarray
394+ """
356395 self ._svd_rank = compute_rank (data , svd_rank = svd_rank )
357396 # Recover the first r modes of the global svd
358397 return compute_svd (data , svd_rank = svd_rank )[0 ]
@@ -388,7 +427,8 @@ def _build_initizialization(self):
388427 and self ._cluster_centroids is not None
389428 ):
390429 raise ValueError (
391- "Only one of `init_alpha` and `cluster_centroids` can be provided"
430+ "Only one of `init_alpha` and `cluster_centroids` can be"
431+ " provided"
392432 )
393433 # If not initial values are provided return None by default.
394434 else :
@@ -589,8 +629,7 @@ def get_window_indices(self, k):
589629 sample_start = self ._step_size * k
590630 if k == self ._n_slides - 1 and self ._non_integer_n_slide :
591631 return slice (- self ._window_length , None )
592- else :
593- return slice (sample_start , sample_start + self ._window_length )
632+ return slice (sample_start , sample_start + self ._window_length )
594633
595634 def cluster_omega (
596635 self ,
@@ -826,10 +865,6 @@ def plot_omega_histogram(self):
826865 :return fig: Figure handle for the plot
827866 :return ax: Axes handle for the plot
828867 """
829- # Reshape the omega array into a 1d array
830- omega_array = self .omega_array
831- # n_slides = omega_array.shape[0]
832- # svd_rank = omega_array.shape[1]
833868
834869 # Apply the transformation to omega
835870 omega_transform = self .transform_omega (
@@ -865,12 +900,6 @@ def plot_omega_time_series(self):
865900 fig , ax = plt .subplots (1 , 1 )
866901 colors = plt .rcParams ["axes.prop_cycle" ].by_key ()["color" ]
867902
868- # Reshape the omega array into a 1d array
869- omega_array = self .omega_array
870- # n_slides = omega_array.shape[0]
871- # svd_rank = omega_array.shape[1]
872- # omega_rshp = omega_array.reshape(n_slides * svd_rank)
873-
874903 # Apply the transformation to omega
875904 omega_transform = self .transform_omega (
876905 self .omega_array .flatten (), transform_method = self ._transform_method
@@ -1360,7 +1389,10 @@ def plot_time_series(
13601389 for n in range (self .n_components ):
13611390 ax = axes [n + 1 ]
13621391 if n == 0 :
1363- title = "blue = Low-frequency component, black = high frequency residual"
1392+ title = (
1393+ "blue = Low-frequency component, black = high "
1394+ "frequency residual"
1395+ )
13641396 ax .plot (xr_sep [n , space_index , :] - ground_truth_mean )
13651397 else :
13661398 title = "Band period = {:.0f} window length" .format (
@@ -1386,7 +1418,10 @@ def plot_time_series(
13861418 label = "Residual" ,
13871419 )
13881420 ax .set_title (
1389- "black=input data, yellow=low-frequency, blue=high-frequency, red=residual"
1421+ (
1422+ "black=input data, yellow=low-frequency, "
1423+ "blue=high-frequency, red=residual"
1424+ )
13901425 )
13911426 else :
13921427 ax .set_title (
0 commit comments