@@ -21,62 +21,66 @@ def test_TorchOperator(par):
2121 """
2222 # temporarily, skip tests on mac as torch seems not to recognized
2323 # numpy when v2 is installed
24- if platform .system () != "Darwin" :
25- Dop = MatrixMult (np .random .normal (0.0 , 1.0 , (par ["ny" ], par ["nx" ])))
26- Top = TorchOperator (Dop , batch = False )
24+ if platform .system () == "Darwin" :
25+ return
2726
28- x = np .random .normal (0.0 , 1.0 , par ["nx" ])
29- xt = torch .from_numpy (x ).view (- 1 )
30- xt .requires_grad = True
31- v = torch .randn (par ["ny" ])
27+ Dop = MatrixMult (np .random .normal (0.0 , 1.0 , (par ["ny" ], par ["nx" ])))
28+ Top = TorchOperator (Dop , batch = False )
3229
33- # pylops operator
34- y = Dop * x
35- xadj = Dop .H * v
30+ x = np .random .normal (0.0 , 1.0 , par ["nx" ])
31+ xt = torch .from_numpy (x ).view (- 1 )
32+ xt .requires_grad = True
33+ v = torch .randn (par ["ny" ])
3634
37- # torch operator
38- yt = Top . apply ( xt )
39- yt . backward ( v , retain_graph = True )
35+ # pylops operator
36+ y = Dop * x
37+ xadj = Dop . H * v
4038
41- assert_array_equal (y , yt .detach ().cpu ().numpy ())
42- assert_array_equal (xadj , xt .grad .cpu ().numpy ())
39+ # torch operator
40+ yt = Top .apply (xt )
41+ yt .backward (v , retain_graph = True )
42+
43+ assert_array_equal (y , yt .detach ().cpu ().numpy ())
44+ assert_array_equal (xadj , xt .grad .cpu ().numpy ())
4345
4446
4547@pytest .mark .parametrize ("par" , [(par1 )])
4648def test_TorchOperator_batch (par ):
4749 """Apply forward for input with multiple samples (= batch) and flattened arrays"""
4850 # temporarily, skip tests on mac as torch seems not to recognized
4951 # numpy when v2 is installed
50- if platform .system () != "Darwin" :
51- Dop = MatrixMult (np .random .normal (0.0 , 1.0 , (par ["ny" ], par ["nx" ])))
52- Top = TorchOperator (Dop , batch = True )
52+ if platform .system () == "Darwin" :
53+ return
54+
55+ Dop = MatrixMult (np .random .normal (0.0 , 1.0 , (par ["ny" ], par ["nx" ])))
56+ Top = TorchOperator (Dop , batch = True )
5357
54- x = np .random .normal (0.0 , 1.0 , (4 , par ["nx" ]))
55- xt = torch .from_numpy (x )
56- xt .requires_grad = True
58+ x = np .random .normal (0.0 , 1.0 , (4 , par ["nx" ]))
59+ xt = torch .from_numpy (x )
60+ xt .requires_grad = True
5761
58- y = Dop .matmat (x .T ).T
59- yt = Top .apply (xt )
62+ y = Dop .matmat (x .T ).T
63+ yt = Top .apply (xt )
6064
61- assert_array_equal (y , yt .detach ().cpu ().numpy ())
65+ assert_array_equal (y , yt .detach ().cpu ().numpy ())
6266
6367
6468@pytest .mark .parametrize ("par" , [(par1 )])
6569def test_TorchOperator_batch_nd (par ):
6670 """Apply forward for input with multiple samples (= batch) and nd-arrays"""
6771 # temporarily, skip tests on mac as torch seems not to recognized
6872 # numpy when v2 is installed
69- if platform .system () ! = "Darwin" :
70- Dop = MatrixMult (
71- np . random . normal ( 0.0 , 1.0 , ( par [ "ny" ], par [ "nx" ])), otherdims = ( 2 ,)
72- )
73- Top = TorchOperator (Dop , batch = True , flatten = False )
73+ if platform .system () = = "Darwin" :
74+ return
75+
76+ Dop = MatrixMult ( np . random . normal ( 0.0 , 1.0 , ( par [ "ny" ], par [ "nx" ])), otherdims = ( 2 ,) )
77+ Top = TorchOperator (Dop , batch = True , flatten = False )
7478
75- x = np .random .normal (0.0 , 1.0 , (4 , par ["nx" ], 2 ))
76- xt = torch .from_numpy (x )
77- xt .requires_grad = True
79+ x = np .random .normal (0.0 , 1.0 , (4 , par ["nx" ], 2 ))
80+ xt = torch .from_numpy (x )
81+ xt .requires_grad = True
7882
79- y = (Dop @ x .transpose (1 , 2 , 0 )).transpose (2 , 0 , 1 )
80- yt = Top .apply (xt )
83+ y = (Dop @ x .transpose (1 , 2 , 0 )).transpose (2 , 0 , 1 )
84+ yt = Top .apply (xt )
8185
82- assert_array_equal (y , yt .detach ().cpu ().numpy ())
86+ assert_array_equal (y , yt .detach ().cpu ().numpy ())
0 commit comments