Skip to content

Commit 45ae1a5

Browse files
committed
address comments
1 parent a445c5c commit 45ae1a5

File tree

4 files changed

+65
-120
lines changed

4 files changed

+65
-120
lines changed

mkl_fft/_float_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __downcast_float128_array(x):
8181

8282
def __supported_array_or_not_implemented(x):
8383
"""
84-
Used in _scipy_fft_backend to convert array to float32,
84+
Used in _scipy_fft to convert array to float32,
8585
float64, complex64, or complex128 type or return NotImplemented
8686
"""
8787
__x = np.asarray(x)

mkl_fft/_numpy_fft.py

Lines changed: 56 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -74,35 +74,44 @@
7474
import warnings
7575

7676
import numpy as np
77-
from numpy import array, asanyarray, conjugate, prod, sqrt, take
77+
from numpy import array, conjugate, prod, sqrt, take
7878

7979
from . import _float_utils
8080
from . import _pydfti as mkl_fft # pylint: disable=no-name-in-module
8181

8282

83+
def _compute_fwd_scale(norm, n, shape):
84+
_check_norm(norm)
85+
if norm in (None, "backward"):
86+
return 1.0
87+
88+
ss = n if n is not None else shape
89+
nn = prod(ss)
90+
fsc = 1 / nn if nn != 0 else 1
91+
if norm == "forward":
92+
return fsc
93+
else: # norm == "ortho"
94+
return sqrt(fsc)
95+
96+
8397
def _check_norm(norm):
8498
if norm not in (None, "ortho", "forward", "backward"):
8599
raise ValueError(
86-
(
87-
"Invalid norm value {} should be None, "
88-
'"ortho", "forward", or "backward".'
89-
).format(norm)
100+
f"Invalid norm value {norm} should be None, 'ortho', 'forward', "
101+
"or 'backward'."
90102
)
91103

92104

93-
def frwd_sc_1d(n, s):
94-
nn = n if n is not None else s
95-
return 1 / nn if nn != 0 else 1
96-
97-
98-
def frwd_sc_nd(s, x_shape):
99-
ss = s if s is not None else x_shape
100-
nn = prod(ss)
101-
return 1 / nn if nn != 0 else 1
102-
105+
def _swap_direction(norm):
106+
_check_norm(norm)
107+
_swap_direction_map = {
108+
"backward": "forward",
109+
None: "forward",
110+
"ortho": "ortho",
111+
"forward": "backward",
112+
}
103113

104-
def ortho_sc_1d(n, s):
105-
return sqrt(frwd_sc_1d(n, s))
114+
return _swap_direction_map[norm]
106115

107116

108117
def trycall(func, args, kwrds):
@@ -208,15 +217,9 @@ def fft(a, n=None, axis=-1, norm=None):
208217
the `numpy.fft` documentation.
209218
210219
"""
211-
_check_norm(norm)
212-
x = _float_utils.__downcast_float128_array(a)
213220

214-
if norm in (None, "backward"):
215-
fsc = 1.0
216-
elif norm == "forward":
217-
fsc = frwd_sc_1d(n, x.shape[axis])
218-
else:
219-
fsc = ortho_sc_1d(n, x.shape[axis])
221+
x = _float_utils.__downcast_float128_array(a)
222+
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
220223

221224
return trycall(mkl_fft.fft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
222225

@@ -307,15 +310,9 @@ def ifft(a, n=None, axis=-1, norm=None):
307310
>>> plt.show()
308311
309312
"""
310-
_check_norm(norm)
311-
x = _float_utils.__downcast_float128_array(a)
312313

313-
if norm in (None, "backward"):
314-
fsc = 1.0
315-
elif norm == "forward":
316-
fsc = frwd_sc_1d(n, x.shape[axis])
317-
else:
318-
fsc = ortho_sc_1d(n, x.shape[axis])
314+
x = _float_utils.__downcast_float128_array(a)
315+
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
319316

320317
return trycall(mkl_fft.ifft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
321318

@@ -404,15 +401,9 @@ def rfft(a, n=None, axis=-1, norm=None):
404401
exploited to compute only the non-negative frequency terms.
405402
406403
"""
407-
_check_norm(norm)
408-
x = _float_utils.__downcast_float128_array(a)
409404

410-
if norm in (None, "backward"):
411-
fsc = 1.0
412-
elif norm == "forward":
413-
fsc = frwd_sc_1d(n, x.shape[axis])
414-
else:
415-
fsc = ortho_sc_1d(n, x.shape[axis])
405+
x = _float_utils.__downcast_float128_array(a)
406+
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
416407

417408
return trycall(mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
418409

@@ -503,16 +494,9 @@ def irfft(a, n=None, axis=-1, norm=None):
503494
specified, and the output array is purely real.
504495
505496
"""
506-
_check_norm(norm)
507-
x = _float_utils.__downcast_float128_array(a)
508497

509-
nn = n if n else 2 * (x.shape[axis] - 1)
510-
if norm in (None, "backward"):
511-
fsc = 1.0
512-
elif norm == "forward":
513-
fsc = frwd_sc_1d(nn, nn)
514-
else:
515-
fsc = ortho_sc_1d(nn, nn)
498+
x = _float_utils.__downcast_float128_array(a)
499+
fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
516500

517501
return trycall(
518502
mkl_fft.irfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}
@@ -595,18 +579,12 @@ def hfft(a, n=None, axis=-1, norm=None):
595579
[ 2., -2.]])
596580
597581
"""
598-
_check_norm(norm)
582+
583+
norm = _swap_direction(norm)
599584
x = _float_utils.__downcast_float128_array(a)
600585
x = array(x, copy=True, dtype=complex)
601586
conjugate(x, out=x)
602-
603-
nn = n if n else 2 * (x.shape[axis] - 1)
604-
if norm in (None, "backward"):
605-
fsc = frwd_sc_1d(nn, nn)
606-
elif norm == "forward":
607-
fsc = 1.0
608-
else:
609-
fsc = ortho_sc_1d(nn, nn)
587+
fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
610588

611589
return trycall(
612590
mkl_fft.irfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}
@@ -670,17 +648,12 @@ def ihfft(a, n=None, axis=-1, norm=None):
670648
array([ 1.-0.j, 2.-0.j, 3.-0.j, 4.-0.j])
671649
672650
"""
651+
673652
# The copy may be required for multithreading.
674-
_check_norm(norm)
653+
norm = _swap_direction(norm)
675654
x = _float_utils.__downcast_float128_array(a)
676655
x = array(x, copy=True, dtype=float)
677-
678-
if norm in (None, "backward"):
679-
fsc = frwd_sc_1d(n, x.shape[axis])
680-
elif norm == "forward":
681-
fsc = 1.0
682-
else:
683-
fsc = ortho_sc_1d(n, x.shape[axis])
656+
fsc = _compute_fwd_scale(norm, n, x.shape[axis])
684657

685658
output = trycall(
686659
mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}
@@ -832,16 +805,10 @@ def fftn(a, s=None, axes=None, norm=None):
832805
>>> plt.show()
833806
834807
"""
835-
_check_norm(norm)
808+
836809
x = _float_utils.__downcast_float128_array(a)
837810
s, axes = _cook_nd_args(x, s, axes)
838-
839-
if norm in (None, "backward"):
840-
fsc = 1.0
841-
elif norm == "forward":
842-
fsc = frwd_sc_nd(s, x.shape)
843-
else:
844-
fsc = sqrt(frwd_sc_nd(s, x.shape))
811+
fsc = _compute_fwd_scale(norm, s, x.shape)
845812

846813
return trycall(mkl_fft.fftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc})
847814

@@ -945,16 +912,10 @@ def ifftn(a, s=None, axes=None, norm=None):
945912
>>> plt.show()
946913
947914
"""
948-
_check_norm(norm)
915+
949916
x = _float_utils.__downcast_float128_array(a)
950917
s, axes = _cook_nd_args(x, s, axes)
951-
952-
if norm in (None, "backward"):
953-
fsc = 1.0
954-
elif norm == "forward":
955-
fsc = frwd_sc_nd(s, x.shape)
956-
else:
957-
fsc = sqrt(frwd_sc_nd(s, x.shape))
918+
fsc = _compute_fwd_scale(norm, s, x.shape)
958919

959920
return trycall(
960921
mkl_fft.ifftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc}
@@ -1053,9 +1014,8 @@ def fft2(a, s=None, axes=(-2, -1), norm=None):
10531014
0.0 +0.j , 0.0 +0.j ]])
10541015
10551016
"""
1056-
_check_norm(norm)
1057-
x = _float_utils.__downcast_float128_array(a)
1058-
return fftn(x, s=s, axes=axes, norm=norm)
1017+
1018+
return fftn(a, s=s, axes=axes, norm=norm)
10591019

10601020

10611021
def ifft2(a, s=None, axes=(-2, -1), norm=None):
@@ -1147,9 +1107,8 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None):
11471107
[ 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j]])
11481108
11491109
"""
1150-
_check_norm(norm)
1151-
x = _float_utils.__downcast_float128_array(a)
1152-
return ifftn(x, s=s, axes=axes, norm=norm)
1110+
1111+
return ifftn(a, s=s, axes=axes, norm=norm)
11531112

11541113

11551114
def rfftn(a, s=None, axes=None, norm=None):
@@ -1241,18 +1200,10 @@ def rfftn(a, s=None, axes=None, norm=None):
12411200
[ 0.+0.j, 0.+0.j]]])
12421201
12431202
"""
1244-
_check_norm(norm)
1203+
12451204
x = _float_utils.__downcast_float128_array(a)
12461205
s, axes = _cook_nd_args(x, s, axes)
1247-
1248-
if norm in (None, "backward"):
1249-
fsc = 1.0
1250-
elif norm == "forward":
1251-
x = asanyarray(x)
1252-
fsc = frwd_sc_nd(s, x.shape)
1253-
else:
1254-
x = asanyarray(x)
1255-
fsc = sqrt(frwd_sc_nd(s, x.shape))
1206+
fsc = _compute_fwd_scale(norm, s, x.shape)
12561207

12571208
return trycall(
12581209
mkl_fft.rfftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc}
@@ -1298,9 +1249,8 @@ def rfft2(a, s=None, axes=(-2, -1), norm=None):
12981249
For more details see `rfftn`.
12991250
13001251
"""
1301-
_check_norm(norm)
1302-
x = _float_utils.__downcast_float128_array(a)
1303-
return rfftn(x, s, axes, norm)
1252+
1253+
return rfftn(a, s, axes, norm)
13041254

13051255

13061256
def irfftn(a, s=None, axes=None, norm=None):
@@ -1394,18 +1344,10 @@ def irfftn(a, s=None, axes=None, norm=None):
13941344
[ 1., 1.]]])
13951345
13961346
"""
1397-
_check_norm(norm)
1347+
13981348
x = _float_utils.__downcast_float128_array(a)
13991349
s, axes = _cook_nd_args(x, s, axes, invreal=True)
1400-
1401-
if norm in (None, "backward"):
1402-
fsc = 1.0
1403-
elif norm == "forward":
1404-
x = asanyarray(x)
1405-
fsc = frwd_sc_nd(s, x.shape)
1406-
else:
1407-
x = asanyarray(x)
1408-
fsc = sqrt(frwd_sc_nd(s, x.shape))
1350+
fsc = _compute_fwd_scale(norm, s, x.shape)
14091351

14101352
return trycall(
14111353
mkl_fft.irfftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc}
@@ -1451,6 +1393,5 @@ def irfft2(a, s=None, axes=(-2, -1), norm=None):
14511393
For more details see `irfftn`.
14521394
14531395
"""
1454-
_check_norm(norm)
1455-
x = _float_utils.__downcast_float128_array(a)
1456-
return irfftn(x, s, axes, norm)
1396+
1397+
return irfftn(a, s, axes, norm)

mkl_fft/_scipy_fft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
4444
:Example:
4545
import scipy.fft
46-
import mkl_fft._scipy_fft_backend as be
46+
import mkl_fft._scipy_fft as be
4747
# Set mkl_fft to be used as backend of SciPy's FFT functions.
4848
scipy.fft.set_global_backend(be)
4949
"""

mkl_fft/tests/test_pocketfft.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
import mkl_fft.interfaces.numpy_fft as mkl_fft
2121

22+
requires_numpy_2 = pytest.mark.skipif(
23+
np.__version__ < "2.0", reason="Requires NumPy >= 2.0"
24+
)
25+
2226

2327
def fft1(x):
2428
L = len(x)
@@ -510,7 +514,7 @@ def test_s_negative_1(self, op):
510514
# should use the whole input array along the first axis
511515
assert op(x, s=(-1, 5), axes=(0, 1)).shape == (10, 5)
512516

513-
@pytest.mark.skipif(np.__version__ < "2.0", reason="Requires numpy >= 2.0")
517+
@requires_numpy_2
514518
@pytest.mark.parametrize(
515519
"op", [mkl_fft.fftn, mkl_fft.ifftn, mkl_fft.rfftn, mkl_fft.irfftn]
516520
)
@@ -519,14 +523,14 @@ def test_s_axes_none(self, op):
519523
with pytest.warns(match="`axes` should not be `None` if `s`"):
520524
op(x, s=(-1, 5))
521525

522-
@pytest.mark.skipif(np.__version__ < "2.0", reason="Requires numpy >= 2.0")
526+
@requires_numpy_2
523527
@pytest.mark.parametrize("op", [mkl_fft.fft2, mkl_fft.ifft2])
524528
def test_s_axes_none_2D(self, op):
525529
x = np.arange(100).reshape(10, 10)
526530
with pytest.warns(match="`axes` should not be `None` if `s`"):
527531
op(x, s=(-1, 5), axes=None)
528532

529-
@pytest.mark.skipif(np.__version__ < "2.0", reason="Requires numpy >= 2.0")
533+
@requires_numpy_2
530534
@pytest.mark.parametrize(
531535
"op",
532536
[

0 commit comments

Comments
 (0)