|
74 | 74 | import warnings
|
75 | 75 |
|
76 | 76 | import numpy as np
|
77 |
| -from numpy import array, asanyarray, conjugate, prod, sqrt, take |
| 77 | +from numpy import array, conjugate, prod, sqrt, take |
78 | 78 |
|
79 | 79 | from . import _float_utils
|
80 | 80 | from . import _pydfti as mkl_fft # pylint: disable=no-name-in-module
|
81 | 81 |
|
82 | 82 |
|
| 83 | +def _compute_fwd_scale(norm, n, shape): |
| 84 | + _check_norm(norm) |
| 85 | + if norm in (None, "backward"): |
| 86 | + return 1.0 |
| 87 | + |
| 88 | + ss = n if n is not None else shape |
| 89 | + nn = prod(ss) |
| 90 | + fsc = 1 / nn if nn != 0 else 1 |
| 91 | + if norm == "forward": |
| 92 | + return fsc |
| 93 | + else: # norm == "ortho" |
| 94 | + return sqrt(fsc) |
| 95 | + |
| 96 | + |
83 | 97 | def _check_norm(norm):
|
84 | 98 | if norm not in (None, "ortho", "forward", "backward"):
|
85 | 99 | raise ValueError(
|
86 |
| - ( |
87 |
| - "Invalid norm value {} should be None, " |
88 |
| - '"ortho", "forward", or "backward".' |
89 |
| - ).format(norm) |
| 100 | + f"Invalid norm value {norm} should be None, 'ortho', 'forward', " |
| 101 | + "or 'backward'." |
90 | 102 | )
|
91 | 103 |
|
92 | 104 |
|
93 |
| -def frwd_sc_1d(n, s): |
94 |
| - nn = n if n is not None else s |
95 |
| - return 1 / nn if nn != 0 else 1 |
96 |
| - |
97 |
| - |
98 |
| -def frwd_sc_nd(s, x_shape): |
99 |
| - ss = s if s is not None else x_shape |
100 |
| - nn = prod(ss) |
101 |
| - return 1 / nn if nn != 0 else 1 |
102 |
| - |
| 105 | +def _swap_direction(norm): |
| 106 | + _check_norm(norm) |
| 107 | + _swap_direction_map = { |
| 108 | + "backward": "forward", |
| 109 | + None: "forward", |
| 110 | + "ortho": "ortho", |
| 111 | + "forward": "backward", |
| 112 | + } |
103 | 113 |
|
104 |
| -def ortho_sc_1d(n, s): |
105 |
| - return sqrt(frwd_sc_1d(n, s)) |
| 114 | + return _swap_direction_map[norm] |
106 | 115 |
|
107 | 116 |
|
108 | 117 | def trycall(func, args, kwrds):
|
@@ -208,15 +217,9 @@ def fft(a, n=None, axis=-1, norm=None):
|
208 | 217 | the `numpy.fft` documentation.
|
209 | 218 |
|
210 | 219 | """
|
211 |
| - _check_norm(norm) |
212 |
| - x = _float_utils.__downcast_float128_array(a) |
213 | 220 |
|
214 |
| - if norm in (None, "backward"): |
215 |
| - fsc = 1.0 |
216 |
| - elif norm == "forward": |
217 |
| - fsc = frwd_sc_1d(n, x.shape[axis]) |
218 |
| - else: |
219 |
| - fsc = ortho_sc_1d(n, x.shape[axis]) |
| 221 | + x = _float_utils.__downcast_float128_array(a) |
| 222 | + fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
220 | 223 |
|
221 | 224 | return trycall(mkl_fft.fft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
|
222 | 225 |
|
@@ -307,15 +310,9 @@ def ifft(a, n=None, axis=-1, norm=None):
|
307 | 310 | >>> plt.show()
|
308 | 311 |
|
309 | 312 | """
|
310 |
| - _check_norm(norm) |
311 |
| - x = _float_utils.__downcast_float128_array(a) |
312 | 313 |
|
313 |
| - if norm in (None, "backward"): |
314 |
| - fsc = 1.0 |
315 |
| - elif norm == "forward": |
316 |
| - fsc = frwd_sc_1d(n, x.shape[axis]) |
317 |
| - else: |
318 |
| - fsc = ortho_sc_1d(n, x.shape[axis]) |
| 314 | + x = _float_utils.__downcast_float128_array(a) |
| 315 | + fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
319 | 316 |
|
320 | 317 | return trycall(mkl_fft.ifft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
|
321 | 318 |
|
@@ -404,15 +401,9 @@ def rfft(a, n=None, axis=-1, norm=None):
|
404 | 401 | exploited to compute only the non-negative frequency terms.
|
405 | 402 |
|
406 | 403 | """
|
407 |
| - _check_norm(norm) |
408 |
| - x = _float_utils.__downcast_float128_array(a) |
409 | 404 |
|
410 |
| - if norm in (None, "backward"): |
411 |
| - fsc = 1.0 |
412 |
| - elif norm == "forward": |
413 |
| - fsc = frwd_sc_1d(n, x.shape[axis]) |
414 |
| - else: |
415 |
| - fsc = ortho_sc_1d(n, x.shape[axis]) |
| 405 | + x = _float_utils.__downcast_float128_array(a) |
| 406 | + fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
416 | 407 |
|
417 | 408 | return trycall(mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
|
418 | 409 |
|
@@ -503,16 +494,9 @@ def irfft(a, n=None, axis=-1, norm=None):
|
503 | 494 | specified, and the output array is purely real.
|
504 | 495 |
|
505 | 496 | """
|
506 |
| - _check_norm(norm) |
507 |
| - x = _float_utils.__downcast_float128_array(a) |
508 | 497 |
|
509 |
| - nn = n if n else 2 * (x.shape[axis] - 1) |
510 |
| - if norm in (None, "backward"): |
511 |
| - fsc = 1.0 |
512 |
| - elif norm == "forward": |
513 |
| - fsc = frwd_sc_1d(nn, nn) |
514 |
| - else: |
515 |
| - fsc = ortho_sc_1d(nn, nn) |
| 498 | + x = _float_utils.__downcast_float128_array(a) |
| 499 | + fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) |
516 | 500 |
|
517 | 501 | return trycall(
|
518 | 502 | mkl_fft.irfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}
|
@@ -595,18 +579,12 @@ def hfft(a, n=None, axis=-1, norm=None):
|
595 | 579 | [ 2., -2.]])
|
596 | 580 |
|
597 | 581 | """
|
598 |
| - _check_norm(norm) |
| 582 | + |
| 583 | + norm = _swap_direction(norm) |
599 | 584 | x = _float_utils.__downcast_float128_array(a)
|
600 | 585 | x = array(x, copy=True, dtype=complex)
|
601 | 586 | conjugate(x, out=x)
|
602 |
| - |
603 |
| - nn = n if n else 2 * (x.shape[axis] - 1) |
604 |
| - if norm in (None, "backward"): |
605 |
| - fsc = frwd_sc_1d(nn, nn) |
606 |
| - elif norm == "forward": |
607 |
| - fsc = 1.0 |
608 |
| - else: |
609 |
| - fsc = ortho_sc_1d(nn, nn) |
| 587 | + fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1)) |
610 | 588 |
|
611 | 589 | return trycall(
|
612 | 590 | mkl_fft.irfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}
|
@@ -670,17 +648,12 @@ def ihfft(a, n=None, axis=-1, norm=None):
|
670 | 648 | array([ 1.-0.j, 2.-0.j, 3.-0.j, 4.-0.j])
|
671 | 649 |
|
672 | 650 | """
|
| 651 | + |
673 | 652 | # The copy may be required for multithreading.
|
674 |
| - _check_norm(norm) |
| 653 | + norm = _swap_direction(norm) |
675 | 654 | x = _float_utils.__downcast_float128_array(a)
|
676 | 655 | x = array(x, copy=True, dtype=float)
|
677 |
| - |
678 |
| - if norm in (None, "backward"): |
679 |
| - fsc = frwd_sc_1d(n, x.shape[axis]) |
680 |
| - elif norm == "forward": |
681 |
| - fsc = 1.0 |
682 |
| - else: |
683 |
| - fsc = ortho_sc_1d(n, x.shape[axis]) |
| 656 | + fsc = _compute_fwd_scale(norm, n, x.shape[axis]) |
684 | 657 |
|
685 | 658 | output = trycall(
|
686 | 659 | mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}
|
@@ -832,16 +805,10 @@ def fftn(a, s=None, axes=None, norm=None):
|
832 | 805 | >>> plt.show()
|
833 | 806 |
|
834 | 807 | """
|
835 |
| - _check_norm(norm) |
| 808 | + |
836 | 809 | x = _float_utils.__downcast_float128_array(a)
|
837 | 810 | s, axes = _cook_nd_args(x, s, axes)
|
838 |
| - |
839 |
| - if norm in (None, "backward"): |
840 |
| - fsc = 1.0 |
841 |
| - elif norm == "forward": |
842 |
| - fsc = frwd_sc_nd(s, x.shape) |
843 |
| - else: |
844 |
| - fsc = sqrt(frwd_sc_nd(s, x.shape)) |
| 811 | + fsc = _compute_fwd_scale(norm, s, x.shape) |
845 | 812 |
|
846 | 813 | return trycall(mkl_fft.fftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc})
|
847 | 814 |
|
@@ -945,16 +912,10 @@ def ifftn(a, s=None, axes=None, norm=None):
|
945 | 912 | >>> plt.show()
|
946 | 913 |
|
947 | 914 | """
|
948 |
| - _check_norm(norm) |
| 915 | + |
949 | 916 | x = _float_utils.__downcast_float128_array(a)
|
950 | 917 | s, axes = _cook_nd_args(x, s, axes)
|
951 |
| - |
952 |
| - if norm in (None, "backward"): |
953 |
| - fsc = 1.0 |
954 |
| - elif norm == "forward": |
955 |
| - fsc = frwd_sc_nd(s, x.shape) |
956 |
| - else: |
957 |
| - fsc = sqrt(frwd_sc_nd(s, x.shape)) |
| 918 | + fsc = _compute_fwd_scale(norm, s, x.shape) |
958 | 919 |
|
959 | 920 | return trycall(
|
960 | 921 | mkl_fft.ifftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc}
|
@@ -1053,9 +1014,8 @@ def fft2(a, s=None, axes=(-2, -1), norm=None):
|
1053 | 1014 | 0.0 +0.j , 0.0 +0.j ]])
|
1054 | 1015 |
|
1055 | 1016 | """
|
1056 |
| - _check_norm(norm) |
1057 |
| - x = _float_utils.__downcast_float128_array(a) |
1058 |
| - return fftn(x, s=s, axes=axes, norm=norm) |
| 1017 | + |
| 1018 | + return fftn(a, s=s, axes=axes, norm=norm) |
1059 | 1019 |
|
1060 | 1020 |
|
1061 | 1021 | def ifft2(a, s=None, axes=(-2, -1), norm=None):
|
@@ -1147,9 +1107,8 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None):
|
1147 | 1107 | [ 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j]])
|
1148 | 1108 |
|
1149 | 1109 | """
|
1150 |
| - _check_norm(norm) |
1151 |
| - x = _float_utils.__downcast_float128_array(a) |
1152 |
| - return ifftn(x, s=s, axes=axes, norm=norm) |
| 1110 | + |
| 1111 | + return ifftn(a, s=s, axes=axes, norm=norm) |
1153 | 1112 |
|
1154 | 1113 |
|
1155 | 1114 | def rfftn(a, s=None, axes=None, norm=None):
|
@@ -1241,18 +1200,10 @@ def rfftn(a, s=None, axes=None, norm=None):
|
1241 | 1200 | [ 0.+0.j, 0.+0.j]]])
|
1242 | 1201 |
|
1243 | 1202 | """
|
1244 |
| - _check_norm(norm) |
| 1203 | + |
1245 | 1204 | x = _float_utils.__downcast_float128_array(a)
|
1246 | 1205 | s, axes = _cook_nd_args(x, s, axes)
|
1247 |
| - |
1248 |
| - if norm in (None, "backward"): |
1249 |
| - fsc = 1.0 |
1250 |
| - elif norm == "forward": |
1251 |
| - x = asanyarray(x) |
1252 |
| - fsc = frwd_sc_nd(s, x.shape) |
1253 |
| - else: |
1254 |
| - x = asanyarray(x) |
1255 |
| - fsc = sqrt(frwd_sc_nd(s, x.shape)) |
| 1206 | + fsc = _compute_fwd_scale(norm, s, x.shape) |
1256 | 1207 |
|
1257 | 1208 | return trycall(
|
1258 | 1209 | mkl_fft.rfftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc}
|
@@ -1298,9 +1249,8 @@ def rfft2(a, s=None, axes=(-2, -1), norm=None):
|
1298 | 1249 | For more details see `rfftn`.
|
1299 | 1250 |
|
1300 | 1251 | """
|
1301 |
| - _check_norm(norm) |
1302 |
| - x = _float_utils.__downcast_float128_array(a) |
1303 |
| - return rfftn(x, s, axes, norm) |
| 1252 | + |
| 1253 | + return rfftn(a, s, axes, norm) |
1304 | 1254 |
|
1305 | 1255 |
|
1306 | 1256 | def irfftn(a, s=None, axes=None, norm=None):
|
@@ -1394,18 +1344,10 @@ def irfftn(a, s=None, axes=None, norm=None):
|
1394 | 1344 | [ 1., 1.]]])
|
1395 | 1345 |
|
1396 | 1346 | """
|
1397 |
| - _check_norm(norm) |
| 1347 | + |
1398 | 1348 | x = _float_utils.__downcast_float128_array(a)
|
1399 | 1349 | s, axes = _cook_nd_args(x, s, axes, invreal=True)
|
1400 |
| - |
1401 |
| - if norm in (None, "backward"): |
1402 |
| - fsc = 1.0 |
1403 |
| - elif norm == "forward": |
1404 |
| - x = asanyarray(x) |
1405 |
| - fsc = frwd_sc_nd(s, x.shape) |
1406 |
| - else: |
1407 |
| - x = asanyarray(x) |
1408 |
| - fsc = sqrt(frwd_sc_nd(s, x.shape)) |
| 1350 | + fsc = _compute_fwd_scale(norm, s, x.shape) |
1409 | 1351 |
|
1410 | 1352 | return trycall(
|
1411 | 1353 | mkl_fft.irfftn, (x,), {"s": s, "axes": axes, "fwd_scale": fsc}
|
@@ -1451,6 +1393,5 @@ def irfft2(a, s=None, axes=(-2, -1), norm=None):
|
1451 | 1393 | For more details see `irfftn`.
|
1452 | 1394 |
|
1453 | 1395 | """
|
1454 |
| - _check_norm(norm) |
1455 |
| - x = _float_utils.__downcast_float128_array(a) |
1456 |
| - return irfftn(x, s, axes, norm) |
| 1396 | + |
| 1397 | + return irfftn(a, s, axes, norm) |
0 commit comments