Skip to content

Commit d8e29c0

Browse files
klapomtezzele
authored andcommitted
Fixed bug in building convolution windows
- Convolution windows had MATLAB style indexing and not python. As a result the windows did not peak at the right place. - Similarly, the window rounding was assymetric. - Causes slight reduction in the reconstructed error in the toy data examples. - `b` amplitudes now fall along PSD of the input data, fixing a lingering issue in the mrCOSTS fit.
1 parent 14c8b34 commit d8e29c0

File tree

3 files changed

+8
-16
lines changed

3 files changed

+8
-16
lines changed

pydmd/costs.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,11 @@ def calculate_lv_kern(window_length, corner_sharpness=None):
325325

326326
lv_kern = (
327327
np.tanh(
328-
corner_sharpness
329-
* np.arange(1, window_length + 1)
330-
/ window_length
328+
corner_sharpness * np.arange(0, window_length) / window_length
331329
)
332330
- np.tanh(
333331
corner_sharpness
334-
* (np.arange(1, window_length + 1) - window_length)
332+
* (np.arange(0, window_length) - window_length - 1)
335333
/ window_length
336334
)
337335
- 1
@@ -354,7 +352,7 @@ def build_kern(window_length):
354352
"""
355353
recon_filter_sd = window_length / 8
356354
recon_filter = np.exp(
357-
-((np.arange(window_length) - (window_length + 1) / 2) ** 2)
355+
-((np.arange(window_length) - (window_length - 1) / 2) ** 2)
358356
/ recon_filter_sd**2
359357
)
360358
return recon_filter

pydmd/mrcosts.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,14 +1041,8 @@ def global_scale_reconstruction(
10411041

10421042
# Convolve each windowed reconstruction with a gaussian filter.
10431043
# Std dev of gaussian filter
1044-
recon_filter_sd = mrd.window_length / 8
1045-
recon_filter = np.exp(
1046-
-(
1047-
(np.arange(mrd.window_length) - (mrd.window_length + 1) / 2)
1048-
** 2
1049-
)
1050-
/ recon_filter_sd**2
1051-
)
1044+
recon_filter = mrd.build_kern(mrd.window_length)
1045+
10521046
omega_classes = omega_classes_list[n_mrd]
10531047

10541048
if mrd.svd_rank < np.max(self._svd_rank_array):

tests/test_mrcosts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ def rhs_UFD(t, y, eta, epsilon, tau):
149149

150150
# Define the expected error in the reconstructions.
151151
expected_global_error = 0.053
152-
expected_lf_error = 0.12
153-
expected_hf_error = 0.19
152+
expected_lf_error = 0.10
153+
expected_hf_error = 0.17
154154
expected_transient_error = 0.3
155155

156156
# Fit mrCOSTS for testing
@@ -360,7 +360,7 @@ def test_plot_local_time_series():
360360
_ = mrc.plot_local_time_series(0, 0, data=data.T)
361361

362362

363-
def tear_down():
363+
def test_tear_down():
364364
"""Remove the files generated in `test_netcdf`"""
365365
file_list = glob.glob("*tests*.nc")
366366
for f in file_list:

0 commit comments

Comments
 (0)