Skip to content

Commit 28ab191

Browse files
committed
Copy over numpy.fft helper functions
1 parent 3cb11f3 commit 28ab191

File tree

4 files changed

+117
-9
lines changed

4 files changed

+117
-9
lines changed

mkl_fft/interfaces/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525

2626
from . import numpy_fft
2727

28+
# find scipy, not scipy.fft, to avoid circular dependency
2829
try:
29-
import scipy.fft
30+
import scipy
3031
except ImportError:
3132
pass
3233
else:

mkl_fft/interfaces/_numpy_fft.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
"ihfft",
5757
]
5858

59-
6059
# copied with modifications from:
6160
# https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py
6261
def _cook_nd_args(a, s=None, axes=None, invreal=False):
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2017, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
"""
28+
FFT hekper functions copied from `numpy.fft` (with some modification) to
29+
prevent circular dependencies when patching NumPy.
30+
"""
31+
32+
import numpy as np
33+
34+
35+
__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
36+
37+
38+
def fftshift(x, axes=None):
39+
"""
40+
Shift the zero-frequency component to the center of the spectrum.
41+
42+
For full documentation refer to `numpy.fft.fftshift`.
43+
44+
"""
45+
x = np.asarray(x)
46+
if axes is None:
47+
axes = tuple(range(x.ndim))
48+
shift = [dim // 2 for dim in x.shape]
49+
elif isinstance(axes, (int, np.integer)):
50+
shift = x.shape[axes] // 2
51+
else:
52+
shift = [x.shape[ax] // 2 for ax in axes]
53+
54+
return np.roll(x, shift, axes)
55+
56+
57+
def ifftshift(x, axes=None):
58+
"""
59+
The inverse of `fftshift`. Although identical for even-length `x`, the
60+
functions differ by one sample for odd-length `x`.
61+
62+
For full documentation refer to `numpy.fft.ifftshift`.
63+
64+
"""
65+
x = np.asarray(x)
66+
if axes is None:
67+
axes = tuple(range(x.ndim))
68+
shift = [-(dim // 2) for dim in x.shape]
69+
elif isinstance(axes, (int, np.integer)):
70+
shift = -(x.shape[axes] // 2)
71+
else:
72+
shift = [-(x.shape[ax] // 2) for ax in axes]
73+
74+
return np.roll(x, shift, axes)
75+
76+
77+
def fftfreq(n, d=1.0, device=None):
78+
"""
79+
Return the Discrete Fourier Transform sample frequencies.
80+
81+
For full documentation refer to `numpy.fft.fftfreq`.
82+
83+
"""
84+
if not isinstance(n, (int, np.integer)):
85+
raise ValueError("n should be an integer")
86+
val = 1.0 / (n * d)
87+
results = np.empty(n, int, device=device)
88+
N = (n - 1) // 2 + 1
89+
p1 = np.arange(0, N, dtype=int, device=device)
90+
results[:N] = p1
91+
p2 = np.arange(-(n // 2), 0, dtype=int, device=device)
92+
results[N:] = p2
93+
return results * val
94+
95+
96+
def rfftfreq(n, d=1.0, device=None):
97+
"""
98+
Return the Discrete Fourier Transform sample frequencies (for usage with
99+
`rfft`, `irfft`).
100+
101+
For full documentation refer to `numpy.fft.rfftfreq`.
102+
103+
"""
104+
if not isinstance(n, (int, np.integer)):
105+
raise ValueError("n should be an integer")
106+
val = 1.0 / (n * d)
107+
N = n // 2 + 1
108+
results = np.arange(0, N, dtype=int, device=device)
109+
return results * val

mkl_fft/interfaces/numpy_fft.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
rfftn,
4343
)
4444

45+
from ._numpy_helper import fftfreq, fftshift, ifftshift, rfftfreq
46+
4547
__all__ = [
4648
"fft",
4749
"ifft",
@@ -57,11 +59,8 @@
5759
"irfftn",
5860
"hfft",
5961
"ihfft",
62+
"fftshift",
63+
"fftfreq",
64+
"rfftfreq",
65+
"ifftshift",
6066
]
61-
62-
# It is important to put the following import here to avoid circular imports
63-
# when patching numpy with mkl_fft
64-
# Added for completing the namespaces
65-
from numpy.fft import fftfreq, fftshift, ifftshift, rfftfreq
66-
67-
__all__ += ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]

0 commit comments

Comments
 (0)