Skip to content

Commit eda7834

Browse files
Fixed rfft/irfft, rfftn/irfftn roundtrip test for interfaces.scipy_fft
1 parent 69bfdbb commit eda7834

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

mkl_fft/_scipy_fft_backend.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,21 +338,22 @@ def irfft(a, n=None, axis=-1, norm=None, workers=None, plan=None):
338338
return NotImplemented
339339
if x is NotImplemented:
340340
return x
341-
fsc = _compute_1d_forward_scale(norm, n, x.shape[axis])
341+
nn = n if n else 2*(x.shape[axis]-1)
342+
fsc = _compute_1d_forward_scale(norm, nn, x.shape[axis])
342343
_check_plan(plan)
343344
with Workers(workers):
344345
output = _pydfti.irfft_numpy(x, n=n, axis=axis, forward_scale=fsc)
345346
return output
346347

347348

348-
def _compute_nd_forward_scale_for_rfft(norm, s, axes, x):
349+
def _compute_nd_forward_scale_for_rfft(norm, s, axes, x, invreal=False):
349350
if norm in (None, "backward"):
350351
fsc = 1.0
351352
elif norm == "forward":
352-
s, axes = _cook_nd_args(x, s, axes)
353+
s, axes = _cook_nd_args(x, s, axes, invreal=invreal)
353354
fsc = _frwd_sc_nd(s, axes, x.shape)
354355
elif norm == "ortho":
355-
s, axes = _cook_nd_args(x, s, axes)
356+
s, axes = _cook_nd_args(x, s, axes, invreal=invreal)
356357
fsc = sqrt(_frwd_sc_nd(s, axes, x.shape))
357358
else:
358359
_check_norm(norm)
@@ -380,7 +381,7 @@ def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None, plan=None):
380381
return NotImplemented
381382
if x is NotImplemented:
382383
return x
383-
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x)
384+
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x, invreal=True)
384385
_check_plan(plan)
385386
with Workers(workers):
386387
output = _pydfti.irfftn_numpy(x, s, axes, forward_scale=fsc)
@@ -408,7 +409,7 @@ def irfftn(a, s=None, axes=None, norm=None, workers=None, plan=None):
408409
return NotImplemented
409410
if x is NotImplemented:
410411
return x
411-
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x)
412+
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x, invreal=True)
412413
_check_plan(plan)
413414
with Workers(workers):
414415
output = _pydfti.irfftn_numpy(x, s, axes, forward_scale=fsc)

0 commit comments

Comments
 (0)