We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0326b1a commit 0853ac7Copy full SHA for 0853ac7
mkl_fft/tests/test_interfaces.py
@@ -29,14 +29,6 @@
29
import numpy as np
30
31
32
-def test_interfaces_has_numpy():
33
- assert hasattr(mfi, 'numpy_fft')
34
-
35
36
-def test_interfaces_has_scipy():
37
- assert hasattr(mfi, 'scipy_fft')
38
39
40
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
41
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128])
42
def test_scipy_fft(norm, dtype):
@@ -151,3 +143,15 @@ def test_scipy_fft_arg_validate():
151
143
with pytest.raises(NotImplementedError):
152
144
mfi.scipy_fft.fft([1,2,3,4], plan="magic")
153
145
146
+
147
+@pytest.mark.parametrize(
148
+ "func",
149
+ [mfi.scipy_fft.rfft2, mfi.numpy_fft.rfft2],
150
+ ids=["scipy", "numpy"],
+)
+def test_axes(func):
+ x = np.arange(24.).reshape(2, 3, 4)
154
+ res = func(x, axes=(1, 2))
155
+ exp = np.fft.rfft2(x, axes=(1, 2))
156
+ tol = 64 * np.finfo(np.float64).eps
157
+ assert np.allclose(res, exp, atol=tol, rtol=tol)
0 commit comments