Skip to content

Commit b8025a6

Browse files
committed
Refactor tests for spatialscores
1 parent fa9cb1c commit b8025a6

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

pysteps/tests/test_spatialscores.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,20 @@
66
from pysteps.tests.helpers import get_precipitation_fields
77
from pysteps.verification import spatialscores
88

9-
try:
10-
import pywt
11-
12-
PYWT_IMPORTED = True
13-
except ImportError:
14-
PYWT_IMPORTED = False
15-
169
R = get_precipitation_fields(num_prev_files=1, return_raw=True)
17-
test_data = [(R[0], R[1], "FSS", [1], [10], None, 0.85161531)]
18-
if PYWT_IMPORTED:
19-
test_data.append((R[0], R[1], "BMSE", [1], None, "Haar", 0.99989651))
20-
10+
test_data = [
11+
(R[0], R[1], "FSS", [1], [10], None, 0.85161531),
12+
(R[0], R[1], "BMSE", [1], None, "Haar", 0.99989651),
13+
]
2114

2215
@pytest.mark.parametrize(
2316
"X_f, X_o, name, thrs, scales, wavelet, expected", test_data
2417
)
2518
def test_intensity_scale(X_f, X_o, name, thrs, scales, wavelet, expected):
2619
"""Test the intensity_scale."""
20+
if name == "BMSE":
21+
pytest.importorskip("pywt")
22+
2723
assert_array_almost_equal(
2824
spatialscores.intensity_scale(X_f, X_o, name, thrs, scales, wavelet)[
2925
0
@@ -33,18 +29,26 @@ def test_intensity_scale(X_f, X_o, name, thrs, scales, wavelet, expected):
3329

3430

3531
R = get_precipitation_fields(num_prev_files=3, return_raw=True)
36-
test_data = [(R[:2], R[2:], "FSS", [1], [10], None, 0.85062658)]
37-
if PYWT_IMPORTED:
38-
test_data.append((R[:2], R[2:], "BMSE", [1], None, "Haar", 0.99985691))
39-
32+
test_data = [
33+
(R[:2], R[2:], "FSS", [1], [10], None),
34+
(R[:2], R[2:], "BMSE", [1], None, "Haar"),
35+
]
4036

4137
@pytest.mark.parametrize(
42-
"R1, R2, name, thrs, scales, wavelet, expected", test_data
38+
"R1, R2, name, thrs, scales, wavelet", test_data
4339
)
4440
def test_intensity_scale_methods(
45-
R1, R2, name, thrs, scales, wavelet, expected
41+
R1, R2, name, thrs, scales, wavelet
4642
):
4743
"""Test the intensity_scale merge."""
44+
if name == "BMSE":
45+
pytest.importorskip("pywt")
46+
47+
# expected reult
48+
int = spatialscores.intensity_scale_init(name, thrs, scales, wavelet)
49+
spatialscores.intensity_scale_accum(int, R1[0], R1[1])
50+
spatialscores.intensity_scale_accum(int, R2[0], R2[1])
51+
expected = spatialscores.intensity_scale_compute(int)[0][0]
4852

4953
# init
5054
int_1 = spatialscores.intensity_scale_init(name, thrs, scales, wavelet)

0 commit comments

Comments
 (0)