Skip to content

Commit 6968167

Browse files
authored
Merge pull request #229 from IntelPython/further-work-to-prevent-circular-dependencies
Add `_numpy_helper.py` to implement `numpy.fft` helper functions
2 parents 0259921 + 62d7e03 commit 6968167

File tree

6 files changed

+217
-60
lines changed

6 files changed

+217
-60
lines changed

mkl_fft/interfaces/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525

2626
from . import numpy_fft
2727

28+
# find scipy, not scipy.fft, to avoid circular dependency
2829
try:
29-
import scipy.fft
30+
import scipy
3031
except ImportError:
3132
pass
3233
else:
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2017, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
"""
28+
FFT helper functions copied from `numpy.fft` (with some modification) to
29+
prevent circular dependencies when patching NumPy.
30+
"""
31+
32+
import numpy as np
33+
34+
__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
35+
36+
37+
def fftshift(x, axes=None):
38+
"""
39+
Shift the zero-frequency component to the center of the spectrum.
40+
41+
For full documentation refer to `numpy.fft.fftshift`.
42+
43+
"""
44+
x = np.asarray(x)
45+
if axes is None:
46+
axes = tuple(range(x.ndim))
47+
shift = [dim // 2 for dim in x.shape]
48+
elif isinstance(axes, (int, np.integer)):
49+
shift = x.shape[axes] // 2
50+
else:
51+
shift = [x.shape[ax] // 2 for ax in axes]
52+
53+
return np.roll(x, shift, axes)
54+
55+
56+
def ifftshift(x, axes=None):
57+
"""
58+
The inverse of `fftshift`. Although identical for even-length `x`, the
59+
functions differ by one sample for odd-length `x`.
60+
61+
For full documentation refer to `numpy.fft.ifftshift`.
62+
63+
"""
64+
x = np.asarray(x)
65+
if axes is None:
66+
axes = tuple(range(x.ndim))
67+
shift = [-(dim // 2) for dim in x.shape]
68+
elif isinstance(axes, (int, np.integer)):
69+
shift = -(x.shape[axes] // 2)
70+
else:
71+
shift = [-(x.shape[ax] // 2) for ax in axes]
72+
73+
return np.roll(x, shift, axes)
74+
75+
76+
def fftfreq(n, d=1.0, device=None):
77+
"""
78+
Return the Discrete Fourier Transform sample frequencies.
79+
80+
For full documentation refer to `numpy.fft.fftfreq`.
81+
82+
"""
83+
if not isinstance(n, (int, np.integer)):
84+
raise ValueError("n should be an integer")
85+
val = 1.0 / (n * d)
86+
# pylint: disable=unexpected-keyword-arg
87+
results = np.empty(n, int, device=device)
88+
# pylint: enable=unexpected-keyword-arg
89+
N = (n - 1) // 2 + 1
90+
# pylint: disable=unexpected-keyword-arg
91+
p1 = np.arange(0, N, dtype=int, device=device)
92+
# pylint: enable=unexpected-keyword-arg
93+
results[:N] = p1
94+
# pylint: disable=unexpected-keyword-arg
95+
p2 = np.arange(-(n // 2), 0, dtype=int, device=device)
96+
# pylint: enable=unexpected-keyword-arg
97+
results[N:] = p2
98+
return results * val
99+
100+
101+
def rfftfreq(n, d=1.0, device=None):
102+
"""
103+
Return the Discrete Fourier Transform sample frequencies (for usage with
104+
`rfft`, `irfft`).
105+
106+
For full documentation refer to `numpy.fft.rfftfreq`.
107+
108+
"""
109+
if not isinstance(n, (int, np.integer)):
110+
raise ValueError("n should be an integer")
111+
val = 1.0 / (n * d)
112+
N = n // 2 + 1
113+
# pylint: disable=unexpected-keyword-arg
114+
results = np.arange(0, N, dtype=int, device=device)
115+
# pylint: enable=unexpected-keyword-arg
116+
return results * val

mkl_fft/interfaces/_scipy_fft.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
import mkl
3838
import numpy as np
39-
import scipy
4039

4140
import mkl_fft
4241

@@ -62,10 +61,6 @@
6261
"ihfft2",
6362
"hfftn",
6463
"ihfftn",
65-
"fftshift",
66-
"ifftshift",
67-
"fftfreq",
68-
"rfftfreq",
6964
"get_workers",
7065
"set_workers",
7166
]
@@ -655,49 +650,6 @@ def ihfftn(
655650
return result
656651

657652

658-
# define thin wrappers for scipy functions to avoid circular dependencies
659-
def fftfreq(n, d=1.0, *, xp=None, device=None):
660-
"""
661-
Return the Discrete Fourier Transform sample frequencies.
662-
663-
For full documentation refer to `scipy.fft.fftfreq`.
664-
665-
"""
666-
return scipy.fft.fftfreq(n, d=d, xp=xp, device=device)
667-
668-
669-
def rfftfreq(n, d=1.0, *, xp=None, device=None):
670-
"""
671-
Return the Discrete Fourier Transform sample frequencies (for usage with
672-
`rfft`, `irfft`).
673-
674-
For full documentation refer to `scipy.fft.rfftfreq`.
675-
676-
"""
677-
return scipy.fft.rfftfreq(n, d=d, xp=xp, device=device)
678-
679-
680-
def fftshift(x, axes=None):
681-
"""
682-
Shift the zero-frequency component to the center of the spectrum.
683-
684-
For full documentation refer to `scipy.fft.fftshift`.
685-
686-
"""
687-
return scipy.fft.fftshift(x, axes=axes)
688-
689-
690-
def ifftshift(x, axes=None):
691-
"""
692-
The inverse of `fftshift`. Although identical for even-length `x`, the
693-
functions differ by one sample for odd-length `x`.
694-
695-
For full documentation refer to `scipy.fft.ifftshift`.
696-
697-
"""
698-
return scipy.fft.ifftshift(x, axes=axes)
699-
700-
701653
def get_workers():
702654
"""
703655
Gets the number of workers used by mkl_fft by default.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2017, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
"""
28+
FFT helper functions copied from `scipy.fft` (with some modification) to
29+
prevent circular dependencies when patching NumPy.
30+
"""
31+
32+
import numpy as np
33+
from scipy._lib._array_api import array_namespace
34+
35+
__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
36+
37+
38+
def fftfreq(n, d=1.0, *, xp=None, device=None):
39+
"""
40+
Return the Discrete Fourier Transform sample frequencies.
41+
42+
For full documentation refer to `scipy.fft.fftfreq`.
43+
44+
"""
45+
xp = np if xp is None else xp
46+
if hasattr(xp, "fft"):
47+
return xp.fft.fftfreq(n, d=d, device=device)
48+
return np.fft.fftfreq(n, d=d, device=device)
49+
50+
51+
def rfftfreq(n, d=1.0, *, xp=None, device=None):
52+
"""
53+
Return the Discrete Fourier Transform sample frequencies (for usage with
54+
`rfft`, `irfft`).
55+
56+
For full documentation refer to `scipy.fft.rfftfreq`.
57+
58+
"""
59+
xp = np if xp is None else xp
60+
if hasattr(xp, "fft"):
61+
return xp.fft.rfftfreq(n, d=d, device=device)
62+
return np.fft.rfftfreq(n, d=d, device=device)
63+
64+
65+
def fftshift(x, axes=None):
66+
"""
67+
Shift the zero-frequency component to the center of the spectrum.
68+
69+
For full documentation refer to `scipy.fft.fftshift`.
70+
71+
"""
72+
xp = array_namespace(x)
73+
if hasattr(xp, "fft"):
74+
return xp.fft.fftshift(x, axes=axes)
75+
x = np.asarray(x)
76+
y = np.fft.fftshift(x, axes=axes)
77+
return xp.asarray(y)
78+
79+
80+
def ifftshift(x, axes=None):
81+
"""
82+
The inverse of `fftshift`. Although identical for even-length `x`, the
83+
functions differ by one sample for odd-length `x`.
84+
85+
For full documentation refer to `scipy.fft.ifftshift`.
86+
87+
"""
88+
xp = array_namespace(x)
89+
if hasattr(xp, "fft"):
90+
return xp.fft.ifftshift(x, axes=axes)
91+
x = np.asarray(x)
92+
y = np.fft.ifftshift(x, axes=axes)
93+
return xp.asarray(y)

mkl_fft/interfaces/numpy_fft.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
rfft2,
4242
rfftn,
4343
)
44+
from ._numpy_helper import fftfreq, fftshift, ifftshift, rfftfreq
4445

4546
__all__ = [
4647
"fft",
@@ -57,11 +58,8 @@
5758
"irfftn",
5859
"hfft",
5960
"ihfft",
61+
"fftshift",
62+
"fftfreq",
63+
"rfftfreq",
64+
"ifftshift",
6065
]
61-
62-
# It is important to put the following import here to avoid circular imports
63-
# when patching numpy with mkl_fft
64-
# Added for completing the namespaces
65-
from numpy.fft import fftfreq, fftshift, ifftshift, rfftfreq
66-
67-
__all__ += ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]

mkl_fft/interfaces/scipy_fft.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,14 @@
2828
from ._scipy_fft import (
2929
fft,
3030
fft2,
31-
fftfreq,
3231
fftn,
33-
fftshift,
3432
get_workers,
3533
hfft,
3634
hfft2,
3735
hfftn,
3836
ifft,
3937
ifft2,
4038
ifftn,
41-
ifftshift,
4239
ihfft,
4340
ihfft2,
4441
ihfftn,
@@ -47,10 +44,10 @@
4744
irfftn,
4845
rfft,
4946
rfft2,
50-
rfftfreq,
5147
rfftn,
5248
set_workers,
5349
)
50+
from ._scipy_helper import fftfreq, fftshift, ifftshift, rfftfreq
5451

5552
__all__ = [
5653
"fft",

0 commit comments

Comments
 (0)