Skip to content

Commit f8e3323

Browse files
committed
chore: simplify testing logic
1 parent 734776d commit f8e3323

File tree

1 file changed

+39
-35
lines changed

1 file changed

+39
-35
lines changed

pytests/test_torchoperator.py

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,62 +21,66 @@ def test_TorchOperator(par):
2121
"""
2222
# temporarily, skip tests on mac as torch seems not to recognized
2323
# numpy when v2 is installed
24-
if platform.system() != "Darwin":
25-
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
26-
Top = TorchOperator(Dop, batch=False)
24+
if platform.system() == "Darwin":
25+
return
2726

28-
x = np.random.normal(0.0, 1.0, par["nx"])
29-
xt = torch.from_numpy(x).view(-1)
30-
xt.requires_grad = True
31-
v = torch.randn(par["ny"])
27+
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
28+
Top = TorchOperator(Dop, batch=False)
3229

33-
# pylops operator
34-
y = Dop * x
35-
xadj = Dop.H * v
30+
x = np.random.normal(0.0, 1.0, par["nx"])
31+
xt = torch.from_numpy(x).view(-1)
32+
xt.requires_grad = True
33+
v = torch.randn(par["ny"])
3634

37-
# torch operator
38-
yt = Top.apply(xt)
39-
yt.backward(v, retain_graph=True)
35+
# pylops operator
36+
y = Dop * x
37+
xadj = Dop.H * v
4038

41-
assert_array_equal(y, yt.detach().cpu().numpy())
42-
assert_array_equal(xadj, xt.grad.cpu().numpy())
39+
# torch operator
40+
yt = Top.apply(xt)
41+
yt.backward(v, retain_graph=True)
42+
43+
assert_array_equal(y, yt.detach().cpu().numpy())
44+
assert_array_equal(xadj, xt.grad.cpu().numpy())
4345

4446

4547
@pytest.mark.parametrize("par", [(par1)])
4648
def test_TorchOperator_batch(par):
4749
"""Apply forward for input with multiple samples (= batch) and flattened arrays"""
4850
# temporarily, skip tests on mac as torch seems not to recognized
4951
# numpy when v2 is installed
50-
if platform.system() != "Darwin":
51-
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
52-
Top = TorchOperator(Dop, batch=True)
52+
if platform.system() == "Darwin":
53+
return
54+
55+
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])))
56+
Top = TorchOperator(Dop, batch=True)
5357

54-
x = np.random.normal(0.0, 1.0, (4, par["nx"]))
55-
xt = torch.from_numpy(x)
56-
xt.requires_grad = True
58+
x = np.random.normal(0.0, 1.0, (4, par["nx"]))
59+
xt = torch.from_numpy(x)
60+
xt.requires_grad = True
5761

58-
y = Dop.matmat(x.T).T
59-
yt = Top.apply(xt)
62+
y = Dop.matmat(x.T).T
63+
yt = Top.apply(xt)
6064

61-
assert_array_equal(y, yt.detach().cpu().numpy())
65+
assert_array_equal(y, yt.detach().cpu().numpy())
6266

6367

6468
@pytest.mark.parametrize("par", [(par1)])
6569
def test_TorchOperator_batch_nd(par):
6670
"""Apply forward for input with multiple samples (= batch) and nd-arrays"""
6771
# temporarily, skip tests on mac as torch seems not to recognized
6872
# numpy when v2 is installed
69-
if platform.system() != "Darwin":
70-
Dop = MatrixMult(
71-
np.random.normal(0.0, 1.0, (par["ny"], par["nx"])), otherdims=(2,)
72-
)
73-
Top = TorchOperator(Dop, batch=True, flatten=False)
73+
if platform.system() == "Darwin":
74+
return
75+
76+
Dop = MatrixMult(np.random.normal(0.0, 1.0, (par["ny"], par["nx"])), otherdims=(2,))
77+
Top = TorchOperator(Dop, batch=True, flatten=False)
7478

75-
x = np.random.normal(0.0, 1.0, (4, par["nx"], 2))
76-
xt = torch.from_numpy(x)
77-
xt.requires_grad = True
79+
x = np.random.normal(0.0, 1.0, (4, par["nx"], 2))
80+
xt = torch.from_numpy(x)
81+
xt.requires_grad = True
7882

79-
y = (Dop @ x.transpose(1, 2, 0)).transpose(2, 0, 1)
80-
yt = Top.apply(xt)
83+
y = (Dop @ x.transpose(1, 2, 0)).transpose(2, 0, 1)
84+
yt = Top.apply(xt)
8185

82-
assert_array_equal(y, yt.detach().cpu().numpy())
86+
assert_array_equal(y, yt.detach().cpu().numpy())

0 commit comments

Comments
 (0)