Skip to content

Commit 0853ac7

Browse files
author
Vahid Tavanashad
committed
add a new test
1 parent 0326b1a commit 0853ac7

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

mkl_fft/tests/test_interfaces.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,6 @@
2929
import numpy as np
3030

3131

32-
def test_interfaces_has_numpy():
33-
assert hasattr(mfi, 'numpy_fft')
34-
35-
36-
def test_interfaces_has_scipy():
37-
assert hasattr(mfi, 'scipy_fft')
38-
39-
4032
@pytest.mark.parametrize('norm', [None, "forward", "backward", "ortho"])
4133
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.complex64, np.complex128])
4234
def test_scipy_fft(norm, dtype):
@@ -151,3 +143,15 @@ def test_scipy_fft_arg_validate():
151143
with pytest.raises(NotImplementedError):
152144
mfi.scipy_fft.fft([1,2,3,4], plan="magic")
153145

146+
147+
@pytest.mark.parametrize(
148+
"func",
149+
[mfi.scipy_fft.rfft2, mfi.numpy_fft.rfft2],
150+
ids=["scipy", "numpy"],
151+
)
152+
def test_axes(func):
153+
x = np.arange(24.).reshape(2, 3, 4)
154+
res = func(x, axes=(1, 2))
155+
exp = np.fft.rfft2(x, axes=(1, 2))
156+
tol = 64 * np.finfo(np.float64).eps
157+
assert np.allclose(res, exp, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)