From aab894af4dcf24e75bb91ed946f6ce9f226594fb Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 30 Sep 2025 10:58:56 -0700 Subject: [PATCH 1/3] Copy over numpy.fft helper functions --- mkl_fft/interfaces/__init__.py | 3 +- mkl_fft/interfaces/_numpy_helper.py | 108 ++++++++++++++++++++++++++++ mkl_fft/interfaces/numpy_fft.py | 12 ++-- 3 files changed, 115 insertions(+), 8 deletions(-) create mode 100644 mkl_fft/interfaces/_numpy_helper.py diff --git a/mkl_fft/interfaces/__init__.py b/mkl_fft/interfaces/__init__.py index 1988ba8..ff17c4b 100644 --- a/mkl_fft/interfaces/__init__.py +++ b/mkl_fft/interfaces/__init__.py @@ -25,8 +25,9 @@ from . import numpy_fft +# find scipy, not scipy.fft, to avoid circular dependency try: - import scipy.fft + import scipy except ImportError: pass else: diff --git a/mkl_fft/interfaces/_numpy_helper.py b/mkl_fft/interfaces/_numpy_helper.py new file mode 100644 index 0000000..4f961b5 --- /dev/null +++ b/mkl_fft/interfaces/_numpy_helper.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +FFT helper functions copied from `numpy.fft` (with some modification) to +prevent circular dependencies when patching NumPy. +""" + +import numpy as np + +__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"] + + +def fftshift(x, axes=None): + """ + Shift the zero-frequency component to the center of the spectrum. + + For full documentation refer to `numpy.fft.fftshift`. + + """ + x = np.asarray(x) + if axes is None: + axes = tuple(range(x.ndim)) + shift = [dim // 2 for dim in x.shape] + elif isinstance(axes, (int, np.integer)): + shift = x.shape[axes] // 2 + else: + shift = [x.shape[ax] // 2 for ax in axes] + + return np.roll(x, shift, axes) + + +def ifftshift(x, axes=None): + """ + The inverse of `fftshift`. Although identical for even-length `x`, the + functions differ by one sample for odd-length `x`. + + For full documentation refer to `numpy.fft.ifftshift`. + + """ + x = np.asarray(x) + if axes is None: + axes = tuple(range(x.ndim)) + shift = [-(dim // 2) for dim in x.shape] + elif isinstance(axes, (int, np.integer)): + shift = -(x.shape[axes] // 2) + else: + shift = [-(x.shape[ax] // 2) for ax in axes] + + return np.roll(x, shift, axes) + + +def fftfreq(n, d=1.0, device=None): + """ + Return the Discrete Fourier Transform sample frequencies. + + For full documentation refer to `numpy.fft.fftfreq`. + + """ + if not isinstance(n, (int, np.integer)): + raise ValueError("n should be an integer") + val = 1.0 / (n * d) + results = np.empty(n, int, device=device) + N = (n - 1) // 2 + 1 + p1 = np.arange(0, N, dtype=int, device=device) + results[:N] = p1 + p2 = np.arange(-(n // 2), 0, dtype=int, device=device) + results[N:] = p2 + return results * val + + +def rfftfreq(n, d=1.0, device=None): + """ + Return the Discrete Fourier Transform sample frequencies (for usage with + `rfft`, `irfft`). + + For full documentation refer to `numpy.fft.rfftfreq`. + + """ + if not isinstance(n, (int, np.integer)): + raise ValueError("n should be an integer") + val = 1.0 / (n * d) + N = n // 2 + 1 + results = np.arange(0, N, dtype=int, device=device) + return results * val diff --git a/mkl_fft/interfaces/numpy_fft.py b/mkl_fft/interfaces/numpy_fft.py index 5653ff1..aa74f3d 100644 --- a/mkl_fft/interfaces/numpy_fft.py +++ b/mkl_fft/interfaces/numpy_fft.py @@ -41,6 +41,7 @@ rfft2, rfftn, ) +from ._numpy_helper import fftfreq, fftshift, ifftshift, rfftfreq __all__ = [ "fft", @@ -57,11 +58,8 @@ "irfftn", "hfft", "ihfft", + "fftshift", + "fftfreq", + "rfftfreq", + "ifftshift", ] - -# It is important to put the following import here to avoid circular imports -# when patching numpy with mkl_fft -# Added for completing the namespaces -from numpy.fft import fftfreq, fftshift, ifftshift, rfftfreq - -__all__ += ["fftshift", "ifftshift", "fftfreq", "rfftfreq"] From 0098901349ba7e2d0faebb86992d460c3b63f4e3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 30 Sep 2025 12:17:44 -0700 Subject: [PATCH 2/3] Copy over scipy.fft helper functions to avoid circular dependencies, also vendor (i)fftshift and (r)fftfreq from scipy --- mkl_fft/interfaces/_scipy_fft.py | 48 --------------- mkl_fft/interfaces/_scipy_helper.py | 93 +++++++++++++++++++++++++++++ mkl_fft/interfaces/scipy_fft.py | 5 +- 3 files changed, 94 insertions(+), 52 deletions(-) create mode 100644 mkl_fft/interfaces/_scipy_helper.py diff --git a/mkl_fft/interfaces/_scipy_fft.py b/mkl_fft/interfaces/_scipy_fft.py index 910f652..77429c7 100644 --- a/mkl_fft/interfaces/_scipy_fft.py +++ b/mkl_fft/interfaces/_scipy_fft.py @@ -36,7 +36,6 @@ import mkl import numpy as np -import scipy import mkl_fft @@ -62,10 +61,6 @@ "ihfft2", "hfftn", "ihfftn", - "fftshift", - "ifftshift", - "fftfreq", - "rfftfreq", "get_workers", "set_workers", ] @@ -655,49 +650,6 @@ def ihfftn( return result -# define thin wrappers for scipy functions to avoid circular dependencies -def fftfreq(n, d=1.0, *, xp=None, device=None): - """ - Return the Discrete Fourier Transform sample frequencies. - - For full documentation refer to `scipy.fft.fftfreq`. - - """ - return scipy.fft.fftfreq(n, d=d, xp=xp, device=device) - - -def rfftfreq(n, d=1.0, *, xp=None, device=None): - """ - Return the Discrete Fourier Transform sample frequencies (for usage with - `rfft`, `irfft`). - - For full documentation refer to `scipy.fft.rfftfreq`. - - """ - return scipy.fft.rfftfreq(n, d=d, xp=xp, device=device) - - -def fftshift(x, axes=None): - """ - Shift the zero-frequency component to the center of the spectrum. - - For full documentation refer to `scipy.fft.fftshift`. - - """ - return scipy.fft.fftshift(x, axes=axes) - - -def ifftshift(x, axes=None): - """ - The inverse of `fftshift`. Although identical for even-length `x`, the - functions differ by one sample for odd-length `x`. - - For full documentation refer to `scipy.fft.ifftshift`. - - """ - return scipy.fft.ifftshift(x, axes=axes) - - def get_workers(): """ Gets the number of workers used by mkl_fft by default. diff --git a/mkl_fft/interfaces/_scipy_helper.py b/mkl_fft/interfaces/_scipy_helper.py new file mode 100644 index 0000000..7f75c30 --- /dev/null +++ b/mkl_fft/interfaces/_scipy_helper.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +FFT helper functions copied from `scipy.fft` (with some modification) to +prevent circular dependencies when patching NumPy. +""" + +import numpy as np +from scipy._lib._array_api import array_namespace + +__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"] + + +def fftfreq(n, d=1.0, *, xp=None, device=None): + """ + Return the Discrete Fourier Transform sample frequencies. + + For full documentation refer to `scipy.fft.fftfreq`. + + """ + xp = np if xp is None else xp + if hasattr(xp, "fft"): + return xp.fft.fftfreq(n, d=d, device=device) + return np.fft.fftfreq(n, d=d, device=device) + + +def rfftfreq(n, d=1.0, *, xp=None, device=None): + """ + Return the Discrete Fourier Transform sample frequencies (for usage with + `rfft`, `irfft`). + + For full documentation refer to `scipy.fft.rfftfreq`. + + """ + xp = np if xp is None else xp + if hasattr(xp, "fft"): + return xp.fft.rfftfreq(n, d=d, device=device) + return np.fft.rfftfreq(n, d=d, device=device) + + +def fftshift(x, axes=None): + """ + Shift the zero-frequency component to the center of the spectrum. + + For full documentation refer to `scipy.fft.fftshift`. + + """ + xp = array_namespace(x) + if hasattr(xp, "fft"): + return xp.fft.fftshift(x, axes=axes) + x = np.asarray(x) + y = np.fft.fftshift(x, axes=axes) + return xp.asarray(y) + + +def ifftshift(x, axes=None): + """ + The inverse of `fftshift`. Although identical for even-length `x`, the + functions differ by one sample for odd-length `x`. + + For full documentation refer to `scipy.fft.ifftshift`. + + """ + xp = array_namespace(x) + if hasattr(xp, "fft"): + return xp.fft.ifftshift(x, axes=axes) + x = np.asarray(x) + y = np.fft.ifftshift(x, axes=axes) + return xp.asarray(y) diff --git a/mkl_fft/interfaces/scipy_fft.py b/mkl_fft/interfaces/scipy_fft.py index 4adce52..0f12841 100644 --- a/mkl_fft/interfaces/scipy_fft.py +++ b/mkl_fft/interfaces/scipy_fft.py @@ -28,9 +28,7 @@ from ._scipy_fft import ( fft, fft2, - fftfreq, fftn, - fftshift, get_workers, hfft, hfft2, @@ -38,7 +36,6 @@ ifft, ifft2, ifftn, - ifftshift, ihfft, ihfft2, ihfftn, @@ -47,10 +44,10 @@ irfftn, rfft, rfft2, - rfftfreq, rfftn, set_workers, ) +from ._scipy_helper import fftfreq, fftshift, ifftshift, rfftfreq __all__ = [ "fft", From 0802c25b7d9b5b56a551413670a740dafa259463 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 30 Sep 2025 12:19:01 -0700 Subject: [PATCH 3/3] fix linting of reimplemented numpy helpers numpy constructor argument signatures were updated with a device keyword in numpy 2 and pylint fails as a result --- mkl_fft/interfaces/_numpy_helper.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mkl_fft/interfaces/_numpy_helper.py b/mkl_fft/interfaces/_numpy_helper.py index 4f961b5..1a67812 100644 --- a/mkl_fft/interfaces/_numpy_helper.py +++ b/mkl_fft/interfaces/_numpy_helper.py @@ -83,11 +83,17 @@ def fftfreq(n, d=1.0, device=None): if not isinstance(n, (int, np.integer)): raise ValueError("n should be an integer") val = 1.0 / (n * d) + # pylint: disable=unexpected-keyword-arg results = np.empty(n, int, device=device) + # pylint: enable=unexpected-keyword-arg N = (n - 1) // 2 + 1 + # pylint: disable=unexpected-keyword-arg p1 = np.arange(0, N, dtype=int, device=device) + # pylint: enable=unexpected-keyword-arg results[:N] = p1 + # pylint: disable=unexpected-keyword-arg p2 = np.arange(-(n // 2), 0, dtype=int, device=device) + # pylint: enable=unexpected-keyword-arg results[N:] = p2 return results * val @@ -104,5 +110,7 @@ def rfftfreq(n, d=1.0, device=None): raise ValueError("n should be an integer") val = 1.0 / (n * d) N = n // 2 + 1 + # pylint: disable=unexpected-keyword-arg results = np.arange(0, N, dtype=int, device=device) + # pylint: enable=unexpected-keyword-arg return results * val