Skip to content

Commit 227a1b8

Browse files
committed
Copy over scipy.fft helper functions
to avoid circular dependencies, also vendor (i)fftshift and (r)fftfreq from scipy
1 parent aab894a commit 227a1b8

File tree

3 files changed

+95
-52
lines changed

3 files changed

+95
-52
lines changed

mkl_fft/interfaces/_scipy_fft.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
import mkl
3838
import numpy as np
39-
import scipy
4039

4140
import mkl_fft
4241

@@ -62,10 +61,6 @@
6261
"ihfft2",
6362
"hfftn",
6463
"ihfftn",
65-
"fftshift",
66-
"ifftshift",
67-
"fftfreq",
68-
"rfftfreq",
6964
"get_workers",
7065
"set_workers",
7166
]
@@ -655,49 +650,6 @@ def ihfftn(
655650
return result
656651

657652

658-
# define thin wrappers for scipy functions to avoid circular dependencies
659-
def fftfreq(n, d=1.0, *, xp=None, device=None):
660-
"""
661-
Return the Discrete Fourier Transform sample frequencies.
662-
663-
For full documentation refer to `scipy.fft.fftfreq`.
664-
665-
"""
666-
return scipy.fft.fftfreq(n, d=d, xp=xp, device=device)
667-
668-
669-
def rfftfreq(n, d=1.0, *, xp=None, device=None):
670-
"""
671-
Return the Discrete Fourier Transform sample frequencies (for usage with
672-
`rfft`, `irfft`).
673-
674-
For full documentation refer to `scipy.fft.rfftfreq`.
675-
676-
"""
677-
return scipy.fft.rfftfreq(n, d=d, xp=xp, device=device)
678-
679-
680-
def fftshift(x, axes=None):
681-
"""
682-
Shift the zero-frequency component to the center of the spectrum.
683-
684-
For full documentation refer to `scipy.fft.fftshift`.
685-
686-
"""
687-
return scipy.fft.fftshift(x, axes=axes)
688-
689-
690-
def ifftshift(x, axes=None):
691-
"""
692-
The inverse of `fftshift`. Although identical for even-length `x`, the
693-
functions differ by one sample for odd-length `x`.
694-
695-
For full documentation refer to `scipy.fft.ifftshift`.
696-
697-
"""
698-
return scipy.fft.ifftshift(x, axes=axes)
699-
700-
701653
def get_workers():
702654
"""
703655
Gets the number of workers used by mkl_fft by default.
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#!/usr/bin/env python
2+
# Copyright (c) 2017, Intel Corporation
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of Intel Corporation nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
"""
28+
FFT helper functions copied from `scipy.fft` (with some modification) to
29+
prevent circular dependencies when patching NumPy.
30+
"""
31+
32+
import numpy as np
33+
34+
from scipy._lib._array_api import array_namespace
35+
36+
__all__ = ["fftshift", "ifftshift", "fftfreq", "rfftfreq"]
37+
38+
39+
def fftfreq(n, d=1.0, *, xp=None, device=None):
40+
"""
41+
Return the Discrete Fourier Transform sample frequencies.
42+
43+
For full documentation refer to `scipy.fft.fftfreq`.
44+
45+
"""
46+
xp = np if xp is None else xp
47+
if hasattr(xp, "fft"):
48+
return xp.fft.fftfreq(n, d=d, device=device)
49+
return np.fft.fftfreq(n, d=d, device=device)
50+
51+
52+
def rfftfreq(n, d=1.0, *, xp=None, device=None):
53+
"""
54+
Return the Discrete Fourier Transform sample frequencies (for usage with
55+
`rfft`, `irfft`).
56+
57+
For full documentation refer to `scipy.fft.rfftfreq`.
58+
59+
"""
60+
xp = np if xp is None else xp
61+
if hasattr(xp, "fft"):
62+
return xp.fft.rfftfreq(n, d=d, device=device)
63+
return np.fft.rfftfreq(n, d=d, device=device)
64+
65+
66+
def fftshift(x, axes=None):
67+
"""
68+
Shift the zero-frequency component to the center of the spectrum.
69+
70+
For full documentation refer to `scipy.fft.fftshift`.
71+
72+
"""
73+
xp = array_namespace(x)
74+
if hasattr(xp, "fft"):
75+
return xp.fft.fftshift(x, axes=axes)
76+
x = np.asarray(x)
77+
y = np.fft.fftshift(x, axes=axes)
78+
return xp.asarray(y)
79+
80+
81+
def ifftshift(x, axes=None):
82+
"""
83+
The inverse of `fftshift`. Although identical for even-length `x`, the
84+
functions differ by one sample for odd-length `x`.
85+
86+
For full documentation refer to `scipy.fft.ifftshift`.
87+
88+
"""
89+
xp = array_namespace(x)
90+
if hasattr(xp, "fft"):
91+
return xp.fft.ifftshift(x, axes=axes)
92+
x = np.asarray(x)
93+
y = np.fft.ifftshift(x, axes=axes)
94+
return xp.asarray(y)

mkl_fft/interfaces/scipy_fft.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,14 @@
2828
from ._scipy_fft import (
2929
fft,
3030
fft2,
31-
fftfreq,
3231
fftn,
33-
fftshift,
3432
get_workers,
3533
hfft,
3634
hfft2,
3735
hfftn,
3836
ifft,
3937
ifft2,
4038
ifftn,
41-
ifftshift,
4239
ihfft,
4340
ihfft2,
4441
ihfftn,
@@ -47,10 +44,10 @@
4744
irfftn,
4845
rfft,
4946
rfft2,
50-
rfftfreq,
5147
rfftn,
5248
set_workers,
5349
)
50+
from ._scipy_helper import fftfreq, fftshift, ifftshift, rfftfreq
5451

5552
__all__ = [
5653
"fft",

0 commit comments

Comments
 (0)