Skip to content

Commit b9618f5

Browse files
committed
chore: simplify testing logic
1 parent 45b5fbc commit b9618f5

File tree

1 file changed

+47
-39
lines changed

1 file changed

+47
-39
lines changed

pytests/test_dtcwt.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,67 +20,75 @@ def sequential_array(shape):
2020
@pytest.mark.parametrize("par", [(par1), (par2)])
2121
def test_dtcwt1D_input1D(par):
2222
"""Test for DTCWT with 1D input"""
23-
if int(np_version[0]) < 2:
24-
t = sequential_array((par["ny"],))
23+
if int(np_version[0]) >= 2:
24+
return
2525

26-
for level in range(1, 10):
27-
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
28-
x = Dtcwt @ t
29-
y = Dtcwt.H @ x
26+
t = sequential_array((par["ny"],))
3027

31-
np.testing.assert_allclose(t, y)
28+
for level in range(1, 10):
29+
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
30+
x = Dtcwt @ t
31+
y = Dtcwt.H @ x
32+
33+
np.testing.assert_allclose(t, y)
3234

3335

3436
@pytest.mark.parametrize("par", [(par1), (par2)])
3537
def test_dtcwt1D_input2D(par):
3638
"""Test for DTCWT with 2D input (forward-inverse pair)"""
37-
if int(np_version[0]) < 2:
38-
t = sequential_array(
39-
(
40-
par["ny"],
41-
par["ny"],
42-
)
39+
if int(np_version[0]) >= 2:
40+
return
41+
42+
t = sequential_array(
43+
(
44+
par["ny"],
45+
par["ny"],
4346
)
47+
)
4448

45-
for level in range(1, 10):
46-
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
47-
x = Dtcwt @ t
48-
y = Dtcwt.H @ x
49+
for level in range(1, 10):
50+
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
51+
x = Dtcwt @ t
52+
y = Dtcwt.H @ x
4953

50-
np.testing.assert_allclose(t, y)
54+
np.testing.assert_allclose(t, y)
5155

5256

5357
@pytest.mark.parametrize("par", [(par1), (par2)])
5458
def test_dtcwt1D_input3D(par):
5559
"""Test for DTCWT with 3D input (forward-inverse pair)"""
56-
if int(np_version[0]) < 2:
57-
t = sequential_array((par["ny"], par["ny"], par["ny"]))
60+
if int(np_version[0]) >= 2:
61+
return
62+
63+
t = sequential_array((par["ny"], par["ny"], par["ny"]))
5864

59-
for level in range(1, 10):
60-
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
61-
x = Dtcwt @ t
62-
y = Dtcwt.H @ x
65+
for level in range(1, 10):
66+
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
67+
x = Dtcwt @ t
68+
y = Dtcwt.H @ x
6369

64-
np.testing.assert_allclose(t, y)
70+
np.testing.assert_allclose(t, y)
6571

6672

6773
@pytest.mark.parametrize("par", [(par1), (par2)])
6874
def test_dtcwt1D_birot(par):
6975
"""Test for DTCWT birot (forward-inverse pair)"""
70-
if int(np_version[0]) < 2:
71-
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]
72-
73-
t = sequential_array(
74-
(
75-
par["ny"],
76-
par["ny"],
77-
)
76+
if int(np_version[0]) >= 2:
77+
return
78+
79+
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]
80+
81+
t = sequential_array(
82+
(
83+
par["ny"],
84+
par["ny"],
7885
)
86+
)
7987

80-
for _b in birots:
81-
print(f"birot {_b}")
82-
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"])
83-
x = Dtcwt @ t
84-
y = Dtcwt.H @ x
88+
for _b in birots:
89+
print(f"birot {_b}")
90+
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"])
91+
x = Dtcwt @ t
92+
y = Dtcwt.H @ x
8593

86-
np.testing.assert_allclose(t, y)
94+
np.testing.assert_allclose(t, y)

0 commit comments

Comments
 (0)