Skip to content

Commit 28cc528

Browse files
committed
test: add safeguards to dtcwt and spgl1 for numpy2
1 parent 3f9b941 commit 28cc528

File tree

1 file changed

+42
-38
lines changed

1 file changed

+42
-38
lines changed

pytests/test_dtcwt.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from pylops.signalprocessing import DTCWT
55

6+
# currently test only if numpy<2.0.0 is installed...
7+
np_version = np.__version__.split(".")
8+
69
par1 = {"ny": 10, "nx": 10, "dtype": "float64"}
710
par2 = {"ny": 50, "nx": 50, "dtype": "float64"}
811

@@ -17,66 +20,67 @@ def sequential_array(shape):
1720
@pytest.mark.parametrize("par", [(par1), (par2)])
1821
def test_dtcwt1D_input1D(par):
1922
"""Test for DTCWT with 1D input"""
23+
if int(np_version[0]) < 2:
24+
t = sequential_array((par["ny"],))
2025

21-
t = sequential_array((par["ny"],))
22-
23-
for level in range(1, 10):
24-
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
25-
x = Dtcwt @ t
26-
y = Dtcwt.H @ x
26+
for level in range(1, 10):
27+
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
28+
x = Dtcwt @ t
29+
y = Dtcwt.H @ x
2730

28-
np.testing.assert_allclose(t, y)
31+
np.testing.assert_allclose(t, y)
2932

3033

3134
@pytest.mark.parametrize("par", [(par1), (par2)])
3235
def test_dtcwt1D_input2D(par):
3336
"""Test for DTCWT with 2D input (forward-inverse pair)"""
34-
35-
t = sequential_array(
36-
(
37-
par["ny"],
38-
par["ny"],
37+
if int(np_version[0]) < 2:
38+
t = sequential_array(
39+
(
40+
par["ny"],
41+
par["ny"],
42+
)
3943
)
40-
)
4144

42-
for level in range(1, 10):
43-
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
44-
x = Dtcwt @ t
45-
y = Dtcwt.H @ x
45+
for level in range(1, 10):
46+
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
47+
x = Dtcwt @ t
48+
y = Dtcwt.H @ x
4649

47-
np.testing.assert_allclose(t, y)
50+
np.testing.assert_allclose(t, y)
4851

4952

5053
@pytest.mark.parametrize("par", [(par1), (par2)])
5154
def test_dtcwt1D_input3D(par):
5255
"""Test for DTCWT with 3D input (forward-inverse pair)"""
56+
if int(np_version[0]) < 2:
57+
t = sequential_array((par["ny"], par["ny"], par["ny"]))
5358

54-
t = sequential_array((par["ny"], par["ny"], par["ny"]))
59+
for level in range(1, 10):
60+
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
61+
x = Dtcwt @ t
62+
y = Dtcwt.H @ x
5563

56-
for level in range(1, 10):
57-
Dtcwt = DTCWT(dims=t.shape, level=level, dtype=par["dtype"])
58-
x = Dtcwt @ t
59-
y = Dtcwt.H @ x
60-
61-
np.testing.assert_allclose(t, y)
64+
np.testing.assert_allclose(t, y)
6265

6366

6467
@pytest.mark.parametrize("par", [(par1), (par2)])
6568
def test_dtcwt1D_birot(par):
6669
"""Test for DTCWT birot (forward-inverse pair)"""
67-
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]
68-
69-
t = sequential_array(
70-
(
71-
par["ny"],
72-
par["ny"],
70+
if int(np_version[0]) < 2:
71+
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]
72+
73+
t = sequential_array(
74+
(
75+
par["ny"],
76+
par["ny"],
77+
)
7378
)
74-
)
7579

76-
for _b in birots:
77-
print(f"birot {_b}")
78-
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"])
79-
x = Dtcwt @ t
80-
y = Dtcwt.H @ x
80+
for _b in birots:
81+
print(f"birot {_b}")
82+
Dtcwt = DTCWT(dims=t.shape, biort=_b, dtype=par["dtype"])
83+
x = Dtcwt @ t
84+
y = Dtcwt.H @ x
8185

82-
np.testing.assert_allclose(t, y)
86+
np.testing.assert_allclose(t, y)

0 commit comments

Comments
 (0)