Skip to content

Commit 4f46721

Browse files
author
Vahid Tavanashad
committed
expose mkl_fft.rfft2_numpy and mkl_fft.irfft2_numpy
1 parent daf2815 commit 4f46721

File tree

3 files changed

+21
-4
lines changed

3 files changed

+21
-4
lines changed

mkl_fft/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
from ._pydfti import (fft, ifft, fft2, ifft2, fftn, ifftn, rfft, irfft,
28-
rfft_numpy, irfft_numpy, rfftn_numpy, irfftn_numpy)
28+
rfft_numpy, irfft_numpy, rfft2_numpy, irfft2_numpy,
29+
rfftn_numpy, irfftn_numpy)
2930

3031
from ._version import __version__
3132
import mkl_fft.interfaces
3233

3334
__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn', 'rfft', 'irfft',
34-
'rfft_numpy', 'irfft_numpy', 'rfftn_numpy', 'irfftn_numpy', 'interfaces']
35+
'rfft_numpy', 'irfft_numpy', 'rfft2_numpy', 'irfft2_numpy',
36+
'rfftn_numpy', 'irfftn_numpy', 'interfaces']

mkl_fft/_pydfti.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,11 +1122,11 @@ def ifftn(x, shape=None, axes=None, overwrite_x=False, forward_scale=1.0):
11221122

11231123

11241124
def rfft2_numpy(x, s=None, axes=(-2,-1), forward_scale=1.0):
1125-
return rfftn_numpy(x, s=s, axes=axes, fsc=forward_scale)
1125+
return rfftn_numpy(x, s=s, axes=axes, forward_scale=forward_scale)
11261126

11271127

11281128
def irfft2_numpy(x, s=None, axes=(-2,-1), forward_scale=1.0):
1129-
return irfftn_numpy(x, s=s, axes=axes, fsc=forward_scale)
1129+
return irfftn_numpy(x, s=s, axes=axes, forward_scale=forward_scale)
11301130

11311131

11321132
def _remove_axis(s, axes, axis_to_remove):

mkl_fft/tests/test_fftnd.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,18 @@ def test_gh109():
228228

229229
rtol, atol = _get_rtol_atol(b)
230230
assert_allclose(r1, r2, rtol=rtol, atol=atol)
231+
232+
233+
def test_rfftn_numpy():
234+
x = np.ones((37, 83))
235+
236+
w = mkl_fft.rfftn_numpy(x)
237+
xx = mkl_fft.irfftn_numpy(w, s=x.shape)
238+
tol = 64 * np.finfo(np.dtype(x.dtype)).eps
239+
assert np.allclose(x, xx, atol=tol, rtol=tol)
240+
241+
w = mkl_fft.rfft2_numpy(x)
242+
xx = mkl_fft.irfft2_numpy(w, s=x.shape)
243+
tol = 64 * np.finfo(np.dtype(x.dtype)).eps
244+
assert np.allclose(x, xx, atol=tol, rtol=tol)
245+

0 commit comments

Comments
 (0)