Skip to content

Commit 004ad93

Browse files
committed
update _scipy_fftpack.py
1 parent fd85ef8 commit 004ad93

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

mkl_fft/_scipy_fftpack.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,47 +24,47 @@
2424
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
from . import _float_utils
2827
from . import _pydfti as mkl_fft # pylint: disable=no-name-in-module
28+
from ._float_utils import __upcast_float16_array
2929

3030
__all__ = ["fft", "ifft", "fftn", "ifftn", "fft2", "ifft2", "rfft", "irfft"]
3131

3232

3333
def fft(a, n=None, axis=-1, overwrite_x=False):
34-
x = _float_utils.__upcast_float16_array(a)
34+
x = __upcast_float16_array(a)
3535
return mkl_fft.fft(x, n=n, axis=axis, overwrite_x=overwrite_x)
3636

3737

3838
def ifft(a, n=None, axis=-1, overwrite_x=False):
39-
x = _float_utils.__upcast_float16_array(a)
39+
x = __upcast_float16_array(a)
4040
return mkl_fft.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x)
4141

4242

4343
def fftn(a, shape=None, axes=None, overwrite_x=False):
44-
x = _float_utils.__upcast_float16_array(a)
45-
return mkl_fft.fftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x)
44+
x = __upcast_float16_array(a)
45+
return mkl_fft.fftn(x, s=shape, axes=axes, overwrite_x=overwrite_x)
4646

4747

4848
def ifftn(a, shape=None, axes=None, overwrite_x=False):
49-
x = _float_utils.__upcast_float16_array(a)
50-
return mkl_fft.ifftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x)
49+
x = __upcast_float16_array(a)
50+
return mkl_fft.ifftn(x, s=shape, axes=axes, overwrite_x=overwrite_x)
5151

5252

5353
def fft2(a, shape=None, axes=(-2, -1), overwrite_x=False):
54-
x = _float_utils.__upcast_float16_array(a)
55-
return mkl_fft.fftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x)
54+
x = __upcast_float16_array(a)
55+
return mkl_fft.fftn(x, s=shape, axes=axes, overwrite_x=overwrite_x)
5656

5757

5858
def ifft2(a, shape=None, axes=(-2, -1), overwrite_x=False):
59-
x = _float_utils.__upcast_float16_array(a)
60-
return mkl_fft.ifftn(x, shape=shape, axes=axes, overwrite_x=overwrite_x)
59+
x = __upcast_float16_array(a)
60+
return mkl_fft.ifftn(x, s=shape, axes=axes, overwrite_x=overwrite_x)
6161

6262

6363
def rfft(a, n=None, axis=-1, overwrite_x=False):
64-
x = _float_utils.__upcast_float16_array(a)
64+
x = __upcast_float16_array(a)
6565
return mkl_fft.rfftpack(x, n=n, axis=axis, overwrite_x=overwrite_x)
6666

6767

6868
def irfft(a, n=None, axis=-1, overwrite_x=False):
69-
x = _float_utils.__upcast_float16_array(a)
69+
x = __upcast_float16_array(a)
7070
return mkl_fft.irfftpack(x, n=n, axis=axis, overwrite_x=overwrite_x)

0 commit comments

Comments
 (0)