26
26
27
27
import mkl_fft .interfaces as mfi
28
28
import pytest
29
+ import numpy as np
29
30
30
31
31
32
def test_interfaces_has_numpy ():
@@ -34,3 +35,43 @@ def test_interfaces_has_numpy():
34
35
35
36
def test_interfaces_has_scipy ():
36
37
assert hasattr (mfi , 'scipy_fft' )
38
+
39
+
40
+ @pytest .mark .parametrize ('norm' , [None , "forward" , "backward" , "ortho" ])
41
+ @pytest .mark .parametrize ('dtype' , [np .float32 , np .float64 , np .complex64 , np .complex128 ])
42
+ def test_scipy_fft (norm , dtype ):
43
+ x = np .ones (511 , dtype = dtype )
44
+ w = mfi .scipy_fft .fft (x , norm = norm )
45
+ xx = mfi .scipy_fft .ifft (w , norm = norm )
46
+ tol = 64 * np .finfo (np .dtype (dtype )).eps
47
+ assert np .allclose (x , xx , atol = tol , rtol = tol )
48
+
49
+
50
+ @pytest .mark .parametrize ('norm' , [None , "forward" , "backward" , "ortho" ])
51
+ @pytest .mark .parametrize ('dtype' , [np .float32 , np .float64 ])
52
+ def test_scipy_rfft (norm , dtype ):
53
+ x = np .ones (511 , dtype = dtype )
54
+ w = mfi .scipy_fft .rfft (x , norm = norm )
55
+ xx = mfi .scipy_fft .irfft (w , n = x .shape [0 ], norm = norm )
56
+ tol = 64 * np .finfo (np .dtype (dtype )).eps
57
+ assert np .allclose (x , xx , atol = tol , rtol = tol )
58
+
59
+
60
+ @pytest .mark .parametrize ('norm' , [None , "forward" , "backward" , "ortho" ])
61
+ @pytest .mark .parametrize ('dtype' , [np .float32 , np .float64 , np .complex64 , np .complex128 ])
62
+ def test_scipy_fftn (norm , dtype ):
63
+ x = np .ones ((37 , 83 ), dtype = dtype )
64
+ w = mfi .scipy_fft .fftn (x , norm = norm )
65
+ xx = mfi .scipy_fft .ifftn (w , norm = norm )
66
+ tol = 64 * np .finfo (np .dtype (dtype )).eps
67
+ assert np .allclose (x , xx , atol = tol , rtol = tol )
68
+
69
+
70
+ @pytest .mark .parametrize ('norm' , [None , "forward" , "backward" , "ortho" ])
71
+ @pytest .mark .parametrize ('dtype' , [np .float32 , np .float64 ])
72
+ def test_scipy_rftn (norm , dtype ):
73
+ x = np .ones ((37 , 83 ), dtype = dtype )
74
+ w = mfi .scipy_fft .rfftn (x , norm = norm )
75
+ xx = mfi .scipy_fft .ifftn (w , s = x .shape , norm = norm )
76
+ tol = 64 * np .finfo (np .dtype (dtype )).eps
77
+ assert np .allclose (x , xx , atol = tol , rtol = tol )
0 commit comments