Skip to content

Commit 0098901

Browse files
committed
Copy over scipy.fft helper functions
to avoid circular dependencies, also vendor (i)fftshift and (r)fftfreq from scipy
1 parent aab894a commit 0098901

File tree

3 files changed

+94
-52
lines changed

3 files changed

+94
-52
lines changed

mkl_fft/interfaces/_scipy_fft.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
import mkl
3838
import numpy as np
39-
import scipy
4039

4140
import mkl_fft
4241

@@ -62,10 +61,6 @@
6261
"ihfft2",
6362
"hfftn",
6463
"ihfftn",
65-
"fftshift",
66-
"ifftshift",
67-
"fftfreq",
68-
"rfftfreq",
6964
"get_workers",
7065
"set_workers",
7166
]
@@ -655,49 +650,6 @@ def ihfftn(
655650
return result
656651

657652

658-
# define thin wrappers for scipy functions to avoid circular dependencies
659-
def fftfreq(n, d=1.0, *, xp=None, device=None):
660-
"""
661-
Return the Discrete Fourier Transform sample frequencies.
662-
663-
For full documentation refer to `scipy.fft.fftfreq`.
664-
665-
"""
666-
return scipy.fft.fftfreq(n, d=d, xp=xp, device=device)
667-
668-
669-
def rfftfreq(n, d=1.0, *, xp=None, device=None):
670-
"""
671-
Return the Discrete Fourier Transform sample frequencies (for usage with
672-
`rfft`, `irfft`).
673-
674-
For full documentation refer to `scipy.fft.rfftfreq`.
675-
676-
"""
677-
return scipy.fft.rfftfreq(n, d=d, xp=xp, device=device)
678-
679-
680-
def fftshift(x, axes=None):
681-
"""
682-
Shift the zero-frequency component to the center of the spectrum.
683-
684-
For full documentation refer to `scipy.fft.fftshift`.
685-
686-
"""
687-
return scipy.fft.fftshift(x, axes=axes)
688-
689-
690-
def ifftshift(x, axes=None):
691-
"""
692-
The inverse of `fftshift`. Although identical for even-length `x`, the
693-
functions differ by one sample for odd-length `x`.
694-
695-
For full documentation refer to `scipy.fft.ifftshift`.
696-
697-
"""
698-
return scipy.fft.ifftshift(x, axes=axes)
699-
700-
701653
def get_workers():
702654
"""
703655
Gets the number of workers used by mkl_fft by default.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2017, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
"""
28+
FFT helper functions copied from `scipy.fft` (with some modification) to
29+
prevent circular dependencies when patching NumPy.
30+
"""
31+
32+
import numpy as np
33+
from scipy._lib._array_api import array_namespace
34+
35+
__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
36+
37+
38+
def fftfreq(n, d=1.0, *, xp=None, device=None):
39+
"""
40+
Return the Discrete Fourier Transform sample frequencies.
41+
42+
For full documentation refer to `scipy.fft.fftfreq`.
43+
44+
"""
45+
xp = np if xp is None else xp
46+
if hasattr(xp, "fft"):
47+
return xp.fft.fftfreq(n, d=d, device=device)
48+
return np.fft.fftfreq(n, d=d, device=device)
49+
50+
51+
def rfftfreq(n, d=1.0, *, xp=None, device=None):
52+
"""
53+
Return the Discrete Fourier Transform sample frequencies (for usage with
54+
`rfft`, `irfft`).
55+
56+
For full documentation refer to `scipy.fft.rfftfreq`.
57+
58+
"""
59+
xp = np if xp is None else xp
60+
if hasattr(xp, "fft"):
61+
return xp.fft.rfftfreq(n, d=d, device=device)
62+
return np.fft.rfftfreq(n, d=d, device=device)
63+
64+
65+
def fftshift(x, axes=None):
66+
"""
67+
Shift the zero-frequency component to the center of the spectrum.
68+
69+
For full documentation refer to `scipy.fft.fftshift`.
70+
71+
"""
72+
xp = array_namespace(x)
73+
if hasattr(xp, "fft"):
74+
return xp.fft.fftshift(x, axes=axes)
75+
x = np.asarray(x)
76+
y = np.fft.fftshift(x, axes=axes)
77+
return xp.asarray(y)
78+
79+
80+
def ifftshift(x, axes=None):
81+
"""
82+
The inverse of `fftshift`. Although identical for even-length `x`, the
83+
functions differ by one sample for odd-length `x`.
84+
85+
For full documentation refer to `scipy.fft.ifftshift`.
86+
87+
"""
88+
xp = array_namespace(x)
89+
if hasattr(xp, "fft"):
90+
return xp.fft.ifftshift(x, axes=axes)
91+
x = np.asarray(x)
92+
y = np.fft.ifftshift(x, axes=axes)
93+
return xp.asarray(y)

mkl_fft/interfaces/scipy_fft.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,14 @@
2828
from ._scipy_fft import (
2929
fft,
3030
fft2,
31-
fftfreq,
3231
fftn,
33-
fftshift,
3432
get_workers,
3533
hfft,
3634
hfft2,
3735
hfftn,
3836
ifft,
3937
ifft2,
4038
ifftn,
41-
ifftshift,
4239
ihfft,
4340
ihfft2,
4441
ihfftn,
@@ -47,10 +44,10 @@
4744
irfftn,
4845
rfft,
4946
rfft2,
50-
rfftfreq,
5147
rfftn,
5248
set_workers,
5349
)
50+
from ._scipy_helper import fftfreq, fftshift, ifftshift, rfftfreq
5451

5552
__all__ = [
5653
"fft",

0 commit comments

Comments
 (0)