Skip to content

Commit 9273b25

Browse files
authored
Merge pull request PyDMD#532 from klapo/mrcosts-indexing-cleanup
Fixed bug in building convolution windows for mrCOSTS
2 parents 14c8b34 + f7e527b commit 9273b25

File tree

4 files changed

+28
-33
lines changed

4 files changed

+28
-33
lines changed

pydmd/costs.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,12 +326,11 @@ def calculate_lv_kern(window_length, corner_sharpness=None):
326326
lv_kern = (
327327
np.tanh(
328328
corner_sharpness
329-
* np.arange(1, window_length + 1)
330-
/ window_length
329+
* np.linspace(0, 1, window_length, endpoint=False)
331330
)
332331
- np.tanh(
333332
corner_sharpness
334-
* (np.arange(1, window_length + 1) - window_length)
333+
* (np.arange(0, window_length) - window_length - 1)
335334
/ window_length
336335
)
337336
- 1
@@ -354,7 +353,7 @@ def build_kern(window_length):
354353
"""
355354
recon_filter_sd = window_length / 8
356355
recon_filter = np.exp(
357-
-((np.arange(window_length) - (window_length + 1) / 2) ** 2)
356+
-((np.arange(window_length) - (window_length - 1) / 2) ** 2)
358357
/ recon_filter_sd**2
359358
)
360359
return recon_filter

pydmd/meta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
__project__ = "PyDMD"
1414
__title__ = "pydmd"
15-
__author__ = "Nicola Demo, Marco Tezzele, Francesco Andreuzzi, Sara Ichinaga, Karl Lapo"
15+
__author__ = (
16+
"Nicola Demo, Marco Tezzele, Francesco Andreuzzi, Sara Ichinaga, Karl Lapo"
17+
)
1618
__copyright__ = "Copyright 2017-2024, PyDMD contributors"
1719
__license__ = "MIT"
1820
__version__ = "1.0.0"

pydmd/mrcosts.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,7 @@ def from_netcdf(self, file_list):
560560
self._n_data_vars = n_data_vars
561561
self._n_time_steps = n_time_steps
562562

563-
def to_netcdf(self, filename):
563+
def to_netcdf(self, filename, filepath="."):
564564
"""
565565
Save the mrCoSTS fit to file in netcdf format.
566566
@@ -569,16 +569,22 @@ def to_netcdf(self, filename):
569569
570570
:param filename: Common name shared by each file.
571571
:type filename: str
572+
:param filepath: Path to save the results. Default is the current
573+
directory.
574+
:type filename: str
575+
572576
"""
573577
for c in self._costs_array:
578+
fname = ".".join(
579+
(
580+
filename,
581+
f"window={c.window_length:}",
582+
"nc",
583+
)
584+
)
585+
fpath = os.path.join(filepath, fname)
574586
c.to_xarray().to_netcdf(
575-
".".join(
576-
(
577-
filename,
578-
f"window={c.window_length:}",
579-
"nc",
580-
)
581-
),
587+
fpath,
582588
engine="h5netcdf",
583589
invalid_netcdf=True,
584590
)
@@ -1041,14 +1047,8 @@ def global_scale_reconstruction(
10411047

10421048
# Convolve each windowed reconstruction with a gaussian filter.
10431049
# Std dev of gaussian filter
1044-
recon_filter_sd = mrd.window_length / 8
1045-
recon_filter = np.exp(
1046-
-(
1047-
(np.arange(mrd.window_length) - (mrd.window_length + 1) / 2)
1048-
** 2
1049-
)
1050-
/ recon_filter_sd**2
1051-
)
1050+
recon_filter = mrd.build_kern(mrd.window_length)
1051+
10521052
omega_classes = omega_classes_list[n_mrd]
10531053

10541054
if mrd.svd_rank < np.max(self._svd_rank_array):

tests/test_mrcosts.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ def rhs_UFD(t, y, eta, epsilon, tau):
149149

150150
# Define the expected error in the reconstructions.
151151
expected_global_error = 0.053
152-
expected_lf_error = 0.12
153-
expected_hf_error = 0.19
152+
expected_lf_error = 0.10
153+
expected_hf_error = 0.17
154154
expected_transient_error = 0.3
155155

156156
# Fit mrCOSTS for testing
@@ -286,13 +286,14 @@ def test_omega_transforms():
286286
)
287287

288288

289-
def test_netcdf():
289+
def test_netcdf(tmp_path):
290290
"""
291291
Test the round trip conversion of the mrCOSTS object to file in
292292
netcdf format and back to mrCOSTS.
293293
"""
294-
mrc.to_netcdf("tests")
295-
file_list = glob.glob("*tests*.nc")
294+
# Move the I/O tests to the temporary test directory.
295+
mrc.to_netcdf("tests", filepath=tmp_path)
296+
file_list = glob.glob(os.path.join(tmp_path, "*tests*.nc"))
296297
mrc_from_file = mrCOSTS()
297298
mrc_from_file.from_netcdf(file_list)
298299

@@ -358,10 +359,3 @@ def test_plot_local_time_series():
358359

359360
with raises(ValueError):
360361
_ = mrc.plot_local_time_series(0, 0, data=data.T)
361-
362-
363-
def tear_down():
364-
"""Remove the files generated in `test_netcdf`"""
365-
file_list = glob.glob("*tests*.nc")
366-
for f in file_list:
367-
os.remove(f)

0 commit comments

Comments
 (0)