diff --git a/mkl_fft/interfaces/__init__.py b/mkl_fft/interfaces/__init__.py index 1988ba8..ff17c4b 100644 --- a/mkl_fft/interfaces/__init__.py +++ b/mkl_fft/interfaces/__init__.py @@ -25,8 +25,9 @@ from . import numpy_fft +# find scipy, not scipy.fft, to avoid circular dependency try: - import scipy.fft + import scipy except ImportError: pass else: diff --git a/mkl_fft/interfaces/_numpy_helper.py b/mkl_fft/interfaces/_numpy_helper.py new file mode 100644 index 0000000..1a67812 --- /dev/null +++ b/mkl_fft/interfaces/_numpy_helper.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +FFT helper functions copied from `numpy.fft` (with some modification) to +prevent circular dependencies when patching NumPy. +""" + +import numpy as np + +__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"] + + +def fftshift(x, axes=None): + """ + Shift the zero-frequency component to the center of the spectrum. + + For full documentation refer to `numpy.fft.fftshift`. + + """ + x = np.asarray(x) + if axes is None: + axes = tuple(range(x.ndim)) + shift = [dim // 2 for dim in x.shape] + elif isinstance(axes, (int, np.integer)): + shift = x.shape[axes] // 2 + else: + shift = [x.shape[ax] // 2 for ax in axes] + + return np.roll(x, shift, axes) + + +def ifftshift(x, axes=None): + """ + The inverse of `fftshift`. Although identical for even-length `x`, the + functions differ by one sample for odd-length `x`. + + For full documentation refer to `numpy.fft.ifftshift`. + + """ + x = np.asarray(x) + if axes is None: + axes = tuple(range(x.ndim)) + shift = [-(dim // 2) for dim in x.shape] + elif isinstance(axes, (int, np.integer)): + shift = -(x.shape[axes] // 2) + else: + shift = [-(x.shape[ax] // 2) for ax in axes] + + return np.roll(x, shift, axes) + + +def fftfreq(n, d=1.0, device=None): + """ + Return the Discrete Fourier Transform sample frequencies. + + For full documentation refer to `numpy.fft.fftfreq`. + + """ + if not isinstance(n, (int, np.integer)): + raise ValueError("n should be an integer") + val = 1.0 / (n * d) + # pylint: disable=unexpected-keyword-arg + results = np.empty(n, int, device=device) + # pylint: enable=unexpected-keyword-arg + N = (n - 1) // 2 + 1 + # pylint: disable=unexpected-keyword-arg + p1 = np.arange(0, N, dtype=int, device=device) + # pylint: enable=unexpected-keyword-arg + results[:N] = p1 + # pylint: disable=unexpected-keyword-arg + p2 = np.arange(-(n // 2), 0, dtype=int, device=device) + # pylint: enable=unexpected-keyword-arg + results[N:] = p2 + return results * val + + +def rfftfreq(n, d=1.0, device=None): + """ + Return the Discrete Fourier Transform sample frequencies (for usage with + `rfft`, `irfft`). + + For full documentation refer to `numpy.fft.rfftfreq`. + + """ + if not isinstance(n, (int, np.integer)): + raise ValueError("n should be an integer") + val = 1.0 / (n * d) + N = n // 2 + 1 + # pylint: disable=unexpected-keyword-arg + results = np.arange(0, N, dtype=int, device=device) + # pylint: enable=unexpected-keyword-arg + return results * val diff --git a/mkl_fft/interfaces/_scipy_fft.py b/mkl_fft/interfaces/_scipy_fft.py index 910f652..77429c7 100644 --- a/mkl_fft/interfaces/_scipy_fft.py +++ b/mkl_fft/interfaces/_scipy_fft.py @@ -36,7 +36,6 @@ import mkl import numpy as np -import scipy import mkl_fft @@ -62,10 +61,6 @@ "ihfft2", "hfftn", "ihfftn", - "fftshift", - "ifftshift", - "fftfreq", - "rfftfreq", "get_workers", "set_workers", ] @@ -655,49 +650,6 @@ def ihfftn( return result -# define thin wrappers for scipy functions to avoid circular dependencies -def fftfreq(n, d=1.0, *, xp=None, device=None): - """ - Return the Discrete Fourier Transform sample frequencies. - - For full documentation refer to `scipy.fft.fftfreq`. - - """ - return scipy.fft.fftfreq(n, d=d, xp=xp, device=device) - - -def rfftfreq(n, d=1.0, *, xp=None, device=None): - """ - Return the Discrete Fourier Transform sample frequencies (for usage with - `rfft`, `irfft`). - - For full documentation refer to `scipy.fft.rfftfreq`. - - """ - return scipy.fft.rfftfreq(n, d=d, xp=xp, device=device) - - -def fftshift(x, axes=None): - """ - Shift the zero-frequency component to the center of the spectrum. - - For full documentation refer to `scipy.fft.fftshift`. - - """ - return scipy.fft.fftshift(x, axes=axes) - - -def ifftshift(x, axes=None): - """ - The inverse of `fftshift`. Although identical for even-length `x`, the - functions differ by one sample for odd-length `x`. - - For full documentation refer to `scipy.fft.ifftshift`. - - """ - return scipy.fft.ifftshift(x, axes=axes) - - def get_workers(): """ Gets the number of workers used by mkl_fft by default. diff --git a/mkl_fft/interfaces/_scipy_helper.py b/mkl_fft/interfaces/_scipy_helper.py new file mode 100644 index 0000000..7f75c30 --- /dev/null +++ b/mkl_fft/interfaces/_scipy_helper.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +FFT helper functions copied from `scipy.fft` (with some modification) to +prevent circular dependencies when patching NumPy. +""" + +import numpy as np +from scipy._lib._array_api import array_namespace + +__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"] + + +def fftfreq(n, d=1.0, *, xp=None, device=None): + """ + Return the Discrete Fourier Transform sample frequencies. + + For full documentation refer to `scipy.fft.fftfreq`. + + """ + xp = np if xp is None else xp + if hasattr(xp, "fft"): + return xp.fft.fftfreq(n, d=d, device=device) + return np.fft.fftfreq(n, d=d, device=device) + + +def rfftfreq(n, d=1.0, *, xp=None, device=None): + """ + Return the Discrete Fourier Transform sample frequencies (for usage with + `rfft`, `irfft`). + + For full documentation refer to `scipy.fft.rfftfreq`. + + """ + xp = np if xp is None else xp + if hasattr(xp, "fft"): + return xp.fft.rfftfreq(n, d=d, device=device) + return np.fft.rfftfreq(n, d=d, device=device) + + +def fftshift(x, axes=None): + """ + Shift the zero-frequency component to the center of the spectrum. + + For full documentation refer to `scipy.fft.fftshift`. + + """ + xp = array_namespace(x) + if hasattr(xp, "fft"): + return xp.fft.fftshift(x, axes=axes) + x = np.asarray(x) + y = np.fft.fftshift(x, axes=axes) + return xp.asarray(y) + + +def ifftshift(x, axes=None): + """ + The inverse of `fftshift`. Although identical for even-length `x`, the + functions differ by one sample for odd-length `x`. + + For full documentation refer to `scipy.fft.ifftshift`. + + """ + xp = array_namespace(x) + if hasattr(xp, "fft"): + return xp.fft.ifftshift(x, axes=axes) + x = np.asarray(x) + y = np.fft.ifftshift(x, axes=axes) + return xp.asarray(y) diff --git a/mkl_fft/interfaces/numpy_fft.py b/mkl_fft/interfaces/numpy_fft.py index 5653ff1..aa74f3d 100644 --- a/mkl_fft/interfaces/numpy_fft.py +++ b/mkl_fft/interfaces/numpy_fft.py @@ -41,6 +41,7 @@ rfft2, rfftn, ) +from ._numpy_helper import fftfreq, fftshift, ifftshift, rfftfreq __all__ = [ "fft", @@ -57,11 +58,8 @@ "irfftn", "hfft", "ihfft", + "fftshift", + "fftfreq", + "rfftfreq", + "ifftshift", ] - -# It is important to put the following import here to avoid circular imports -# when patching numpy with mkl_fft -# Added for completing the namespaces -from numpy.fft import fftfreq, fftshift, ifftshift, rfftfreq - -__all__ += ["fftshift", "ifftshift", "fftfreq", "rfftfreq"] diff --git a/mkl_fft/interfaces/scipy_fft.py b/mkl_fft/interfaces/scipy_fft.py index 4adce52..0f12841 100644 --- a/mkl_fft/interfaces/scipy_fft.py +++ b/mkl_fft/interfaces/scipy_fft.py @@ -28,9 +28,7 @@ from ._scipy_fft import ( fft, fft2, - fftfreq, fftn, - fftshift, get_workers, hfft, hfft2, @@ -38,7 +36,6 @@ ifft, ifft2, ifftn, - ifftshift, ihfft, ihfft2, ihfftn, @@ -47,10 +44,10 @@ irfftn, rfft, rfft2, - rfftfreq, rfftn, set_workers, ) +from ._scipy_helper import fftfreq, fftshift, ifftshift, rfftfreq __all__ = [ "fft",