Skip to content

Commit aab894a

Browse files
committed
Copy over numpy.fft helper functions
1 parent 3cb11f3 commit aab894a

File tree

3 files changed

+115
-8
lines changed

3 files changed

+115
-8
lines changed

mkl_fft/interfaces/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525

2626
from . import numpy_fft
2727

28+
# find scipy, not scipy.fft, to avoid circular dependency
2829
try:
29-
import scipy.fft
30+
import scipy
3031
except ImportError:
3132
pass
3233
else:
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2017, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
"""
28+
FFT helper functions copied from `numpy.fft` (with some modification) to
29+
prevent circular dependencies when patching NumPy.
30+
"""
31+
32+
import numpy as np
33+
34+
__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
35+
36+
37+
def fftshift(x, axes=None):
38+
"""
39+
Shift the zero-frequency component to the center of the spectrum.
40+
41+
For full documentation refer to `numpy.fft.fftshift`.
42+
43+
"""
44+
x = np.asarray(x)
45+
if axes is None:
46+
axes = tuple(range(x.ndim))
47+
shift = [dim // 2 for dim in x.shape]
48+
elif isinstance(axes, (int, np.integer)):
49+
shift = x.shape[axes] // 2
50+
else:
51+
shift = [x.shape[ax] // 2 for ax in axes]
52+
53+
return np.roll(x, shift, axes)
54+
55+
56+
def ifftshift(x, axes=None):
57+
"""
58+
The inverse of `fftshift`. Although identical for even-length `x`, the
59+
functions differ by one sample for odd-length `x`.
60+
61+
For full documentation refer to `numpy.fft.ifftshift`.
62+
63+
"""
64+
x = np.asarray(x)
65+
if axes is None:
66+
axes = tuple(range(x.ndim))
67+
shift = [-(dim // 2) for dim in x.shape]
68+
elif isinstance(axes, (int, np.integer)):
69+
shift = -(x.shape[axes] // 2)
70+
else:
71+
shift = [-(x.shape[ax] // 2) for ax in axes]
72+
73+
return np.roll(x, shift, axes)
74+
75+
76+
def fftfreq(n, d=1.0, device=None):
77+
"""
78+
Return the Discrete Fourier Transform sample frequencies.
79+
80+
For full documentation refer to `numpy.fft.fftfreq`.
81+
82+
"""
83+
if not isinstance(n, (int, np.integer)):
84+
raise ValueError("n should be an integer")
85+
val = 1.0 / (n * d)
86+
results = np.empty(n, int, device=device)
87+
N = (n - 1) // 2 + 1
88+
p1 = np.arange(0, N, dtype=int, device=device)
89+
results[:N] = p1
90+
p2 = np.arange(-(n // 2), 0, dtype=int, device=device)
91+
results[N:] = p2
92+
return results * val
93+
94+
95+
def rfftfreq(n, d=1.0, device=None):
96+
"""
97+
Return the Discrete Fourier Transform sample frequencies (for usage with
98+
`rfft`, `irfft`).
99+
100+
For full documentation refer to `numpy.fft.rfftfreq`.
101+
102+
"""
103+
if not isinstance(n, (int, np.integer)):
104+
raise ValueError("n should be an integer")
105+
val = 1.0 / (n * d)
106+
N = n // 2 + 1
107+
results = np.arange(0, N, dtype=int, device=device)
108+
return results * val

mkl_fft/interfaces/numpy_fft.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
rfft2,
4242
rfftn,
4343
)
44+
from ._numpy_helper import fftfreq, fftshift, ifftshift, rfftfreq
4445

4546
__all__ = [
4647
"fft",
@@ -57,11 +58,8 @@
5758
"irfftn",
5859
"hfft",
5960
"ihfft",
61+
"fftshift",
62+
"fftfreq",
63+
"rfftfreq",
64+
"ifftshift",
6065
]
61-
62-
# It is important to put the following import here to avoid circular imports
63-
# when patching numpy with mkl_fft
64-
# Added for completing the namespaces
65-
from numpy.fft import fftfreq, fftshift, ifftshift, rfftfreq
66-
67-
__all__ += ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]

0 commit comments

Comments
 (0)