29
29
30
30
import mkl_fft .interfaces as mfi
31
31
32
+ try :
33
+ scipy_fft = mfi .scipy_fft
34
+ except AttributeError :
35
+ scipy_fft = None
36
+
37
+ interfaces = []
38
+ ids = []
39
+ if scipy_fft is not None :
40
+ interfaces .append (scipy_fft )
41
+ ids .append ("scipy" )
42
+ interfaces .append (mfi .numpy_fft )
43
+ ids .append ("numpy" )
44
+
32
45
33
46
@pytest .mark .parametrize ("norm" , [None , "forward" , "backward" , "ortho" ])
34
47
@pytest .mark .parametrize (
35
48
"dtype" , [np .float32 , np .float64 , np .complex64 , np .complex128 ]
36
49
)
37
50
def test_scipy_fft (norm , dtype ):
51
+ pytest .importorskip ("scipy" , reason = "requires scipy" )
38
52
x = np .ones (511 , dtype = dtype )
39
53
w = mfi .scipy_fft .fft (x , norm = norm , workers = None , plan = None )
40
54
xx = mfi .scipy_fft .ifft (w , norm = norm , workers = None , plan = None )
@@ -57,6 +71,7 @@ def test_numpy_fft(norm, dtype):
57
71
@pytest .mark .parametrize ("norm" , [None , "forward" , "backward" , "ortho" ])
58
72
@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
59
73
def test_scipy_rfft (norm , dtype ):
74
+ pytest .importorskip ("scipy" , reason = "requires scipy" )
60
75
x = np .ones (511 , dtype = dtype )
61
76
w = mfi .scipy_fft .rfft (x , norm = norm , workers = None , plan = None )
62
77
xx = mfi .scipy_fft .irfft (
@@ -87,6 +102,7 @@ def test_numpy_rfft(norm, dtype):
87
102
"dtype" , [np .float32 , np .float64 , np .complex64 , np .complex128 ]
88
103
)
89
104
def test_scipy_fftn (norm , dtype ):
105
+ pytest .importorskip ("scipy" , reason = "requires scipy" )
90
106
x = np .ones ((37 , 83 ), dtype = dtype )
91
107
w = mfi .scipy_fft .fftn (x , norm = norm , workers = None , plan = None )
92
108
xx = mfi .scipy_fft .ifftn (w , norm = norm , workers = None , plan = None )
@@ -109,6 +125,7 @@ def test_numpy_fftn(norm, dtype):
109
125
@pytest .mark .parametrize ("norm" , [None , "forward" , "backward" , "ortho" ])
110
126
@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
111
127
def test_scipy_rfftn (norm , dtype ):
128
+ pytest .importorskip ("scipy" , reason = "requires scipy" )
112
129
x = np .ones ((37 , 83 ), dtype = dtype )
113
130
w = mfi .scipy_fft .rfftn (x , norm = norm , workers = None , plan = None )
114
131
xx = mfi .scipy_fft .irfftn (w , s = x .shape , norm = norm , workers = None , plan = None )
@@ -143,32 +160,30 @@ def _get_blacklisted_dtypes():
143
160
144
161
@pytest .mark .parametrize ("dtype" , _get_blacklisted_dtypes ())
145
162
def test_scipy_no_support_for (dtype ):
163
+ pytest .importorskip ("scipy" , reason = "requires scipy" )
146
164
x = np .ones (16 , dtype = dtype )
147
165
assert_raises (NotImplementedError , mfi .scipy_fft .ifft , x )
148
166
149
167
150
168
def test_scipy_fft_arg_validate ():
169
+ pytest .importorskip ("scipy" , reason = "requires scipy" )
151
170
with pytest .raises (ValueError ):
152
171
mfi .scipy_fft .fft ([1 , 2 , 3 , 4 ], norm = b"invalid" )
153
172
154
173
with pytest .raises (NotImplementedError ):
155
174
mfi .scipy_fft .fft ([1 , 2 , 3 , 4 ], plan = "magic" )
156
175
157
176
158
- @pytest .mark .parametrize (
159
- "func" , [mfi .scipy_fft .rfft2 , mfi .numpy_fft .rfft2 ], ids = ["scipy" , "numpy" ]
160
- )
161
- def test_axes (func ):
177
+ @pytest .mark .parametrize ("interface" , interfaces , ids = ids )
178
+ def test_axes (interface ):
162
179
x = np .arange (24.0 ).reshape (2 , 3 , 4 )
163
- res = func (x , axes = (1 , 2 ))
180
+ res = interface . rfft2 (x , axes = (1 , 2 ))
164
181
exp = np .fft .rfft2 (x , axes = (1 , 2 ))
165
182
tol = 64 * np .finfo (np .float64 ).eps
166
183
assert np .allclose (res , exp , atol = tol , rtol = tol )
167
184
168
185
169
- @pytest .mark .parametrize (
170
- "interface" , [mfi .scipy_fft , mfi .numpy_fft ], ids = ["scipy" , "numpy" ]
171
- )
186
+ @pytest .mark .parametrize ("interface" , interfaces , ids = ids )
172
187
@pytest .mark .parametrize (
173
188
"func" , ["fftshift" , "ifftshift" , "fftfreq" , "rfftfreq" ]
174
189
)
0 commit comments