|
74 | 74 | import warnings
|
75 | 75 |
|
76 | 76 | import numpy as np
|
77 |
| -from numpy import array, conjugate, prod, sqrt, take |
78 | 77 |
|
79 |
| -from . import _float_utils |
80 | 78 | from . import _pydfti as mkl_fft # pylint: disable=no-name-in-module
|
| 79 | +from ._fft_utils import _check_norm, _compute_fwd_scale |
| 80 | +from ._float_utils import __downcast_float128_array |
81 | 81 |
|
82 | 82 |
|
83 |
| -def _compute_fwd_scale(norm, n, shape): |
84 |
| - _check_norm(norm) |
85 |
| - if norm in (None, "backward"): |
86 |
| - return 1.0 |
87 |
| - |
88 |
| - ss = n if n is not None else shape |
89 |
| - nn = prod(ss) |
90 |
| - fsc = 1 / nn if nn != 0 else 1 |
91 |
| - if norm == "forward": |
92 |
| - return fsc |
93 |
| - else: # norm == "ortho" |
94 |
| - return sqrt(fsc) |
95 |
| - |
96 |
| - |
97 |
| -def _check_norm(norm): |
98 |
| - if norm not in (None, "ortho", "forward", "backward"): |
99 |
| - raise ValueError( |
100 |
| - f"Invalid norm value {norm} should be None, 'ortho', 'forward', " |
101 |
| - "or 'backward'." |
| 83 | +# copied from: https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py |
| 84 | +def _cook_nd_args(a, s=None, axes=None, invreal=False): |
| 85 | + if s is None: |
| 86 | + shapeless = True |
| 87 | + if axes is None: |
| 88 | + s = list(a.shape) |
| 89 | + else: |
| 90 | + s = np.take(a.shape, axes) |
| 91 | + else: |
| 92 | + shapeless = False |
| 93 | + s = list(s) |
| 94 | + if axes is None: |
| 95 | + if not shapeless and np.__version__ >= "2.0": |
| 96 | + msg = ( |
| 97 | + "`axes` should not be `None` if `s` is not `None` " |
| 98 | + "(Deprecated in NumPy 2.0). In a future version of NumPy, " |
| 99 | + "this will raise an error and `s[i]` will correspond to " |
| 100 | + "the size along the transformed axis specified by " |
| 101 | + "`axes[i]`. To retain current behaviour, pass a sequence " |
| 102 | + "[0, ..., k-1] to `axes` for an array of dimension k." |
| 103 | + ) |
| 104 | + warnings.warn(msg, DeprecationWarning, stacklevel=3) |
| 105 | + axes = list(range(-len(s), 0)) |
| 106 | + if len(s) != len(axes): |
| 107 | + raise ValueError("Shape and axes have different lengths.") |
| 108 | + if invreal and shapeless: |
| 109 | + s[-1] = (a.shape[axes[-1]] - 1) * 2 |
| 110 | + if None in s and np.__version__ >= "2.0": |
| 111 | + msg = ( |
| 112 | + "Passing an array containing `None` values to `s` is " |
| 113 | + "deprecated in NumPy 2.0 and will raise an error in " |
| 114 | + "a future version of NumPy. To use the default behaviour " |
| 115 | + "of the corresponding 1-D transform, pass the value matching " |
| 116 | + "the default for its `n` parameter. To use the default " |
| 117 | + "behaviour for every axis, the `s` argument can be omitted." |
102 | 118 | )
|
| 119 | + warnings.warn(msg, DeprecationWarning, stacklevel=3) |
| 120 | + # use the whole input array along axis `i` if `s[i] == -1 or None` |
| 121 | + s = [a.shape[_a] if _s in [-1, None] else _s for _s, _a in zip(s, axes)] |
| 122 | + |
| 123 | + return s, axes |
103 | 124 |
|
104 | 125 |
|
105 | 126 | def _swap_direction(norm):
|
@@ -218,7 +239,7 @@ def fft(a, n=None, axis=-1, norm=None):
|
218 | 239 |
|
219 | 240 | """
|
220 | 241 |
|
221 |
| - x = _float_utils.__downcast_float128_array(a) |
| 242 | + x = __downcast_float128_array(a) |
222 | 243 | fsc = _compute_fwd_scale(norm, n, x.shape[axis])
|
223 | 244 |
|
224 | 245 | return trycall(mkl_fft.fft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
|
@@ -311,7 +332,7 @@ def ifft(a, n=None, axis=-1, norm=None):
|
311 | 332 |
|
312 | 333 | """
|
313 | 334 |
|
314 |
| - x = _float_utils.__downcast_float128_array(a) |
| 335 | + x = __downcast_float128_array(a) |
315 | 336 | fsc = _compute_fwd_scale(norm, n, x.shape[axis])
|
316 | 337 |
|
317 | 338 | return trycall(mkl_fft.ifft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
|
@@ -402,7 +423,7 @@ def rfft(a, n=None, axis=-1, norm=None):
|
402 | 423 |
|
403 | 424 | """
|
404 | 425 |
|
405 |
| - x = _float_utils.__downcast_float128_array(a) |
| 426 | + x = __downcast_float128_array(a) |
406 | 427 | fsc = _compute_fwd_scale(norm, n, x.shape[axis])
|
407 | 428 |
|
408 | 429 | return trycall(mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc})
|
@@ -495,7 +516,7 @@ def irfft(a, n=None, axis=-1, norm=None):
|
495 | 516 |
|
496 | 517 | """
|
497 | 518 |
|
498 |
| - x = _float_utils.__downcast_float128_array(a) |
| 519 | + x = __downcast_float128_array(a) |
499 | 520 | fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
|
500 | 521 |
|
501 | 522 | return trycall(
|
@@ -581,9 +602,9 @@ def hfft(a, n=None, axis=-1, norm=None):
|
581 | 602 | """
|
582 | 603 |
|
583 | 604 | norm = _swap_direction(norm)
|
584 |
| - x = _float_utils.__downcast_float128_array(a) |
585 |
| - x = array(x, copy=True, dtype=complex) |
586 |
| - conjugate(x, out=x) |
| 605 | + x = __downcast_float128_array(a) |
| 606 | + x = np.array(x, copy=True, dtype=complex) |
| 607 | + np.conjugate(x, out=x) |
587 | 608 | fsc = _compute_fwd_scale(norm, n, 2 * (x.shape[axis] - 1))
|
588 | 609 |
|
589 | 610 | return trycall(
|
@@ -651,61 +672,18 @@ def ihfft(a, n=None, axis=-1, norm=None):
|
651 | 672 |
|
652 | 673 | # The copy may be required for multithreading.
|
653 | 674 | norm = _swap_direction(norm)
|
654 |
| - x = _float_utils.__downcast_float128_array(a) |
655 |
| - x = array(x, copy=True, dtype=float) |
| 675 | + x = __downcast_float128_array(a) |
| 676 | + x = np.array(x, copy=True, dtype=float) |
656 | 677 | fsc = _compute_fwd_scale(norm, n, x.shape[axis])
|
657 | 678 |
|
658 | 679 | output = trycall(
|
659 | 680 | mkl_fft.rfft, (x,), {"n": n, "axis": axis, "fwd_scale": fsc}
|
660 | 681 | )
|
661 | 682 |
|
662 |
| - conjugate(output, out=output) |
| 683 | + np.conjugate(output, out=output) |
663 | 684 | return output
|
664 | 685 |
|
665 | 686 |
|
666 |
| -# copied from: https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py |
667 |
| -def _cook_nd_args(a, s=None, axes=None, invreal=False): |
668 |
| - if s is None: |
669 |
| - shapeless = True |
670 |
| - if axes is None: |
671 |
| - s = list(a.shape) |
672 |
| - else: |
673 |
| - s = take(a.shape, axes) |
674 |
| - else: |
675 |
| - shapeless = False |
676 |
| - s = list(s) |
677 |
| - if axes is None: |
678 |
| - if not shapeless and np.__version__ >= "2.0": |
679 |
| - msg = ( |
680 |
| - "`axes` should not be `None` if `s` is not `None` " |
681 |
| - "(Deprecated in NumPy 2.0). In a future version of NumPy, " |
682 |
| - "this will raise an error and `s[i]` will correspond to " |
683 |
| - "the size along the transformed axis specified by " |
684 |
| - "`axes[i]`. To retain current behaviour, pass a sequence " |
685 |
| - "[0, ..., k-1] to `axes` for an array of dimension k." |
686 |
| - ) |
687 |
| - warnings.warn(msg, DeprecationWarning, stacklevel=3) |
688 |
| - axes = list(range(-len(s), 0)) |
689 |
| - if len(s) != len(axes): |
690 |
| - raise ValueError("Shape and axes have different lengths.") |
691 |
| - if invreal and shapeless: |
692 |
| - s[-1] = (a.shape[axes[-1]] - 1) * 2 |
693 |
| - if None in s and np.__version__ >= "2.0": |
694 |
| - msg = ( |
695 |
| - "Passing an array containing `None` values to `s` is " |
696 |
| - "deprecated in NumPy 2.0 and will raise an error in " |
697 |
| - "a future version of NumPy. To use the default behaviour " |
698 |
| - "of the corresponding 1-D transform, pass the value matching " |
699 |
| - "the default for its `n` parameter. To use the default " |
700 |
| - "behaviour for every axis, the `s` argument can be omitted." |
701 |
| - ) |
702 |
| - warnings.warn(msg, DeprecationWarning, stacklevel=3) |
703 |
| - # use the whole input array along axis `i` if `s[i] == -1 or None` |
704 |
| - s = [a.shape[_a] if _s in [-1, None] else _s for _s, _a in zip(s, axes)] |
705 |
| - |
706 |
| - return s, axes |
707 |
| - |
708 |
| - |
709 | 687 | def fftn(a, s=None, axes=None, norm=None):
|
710 | 688 | """
|
711 | 689 | Compute the N-dimensional discrete Fourier Transform.
|
@@ -806,7 +784,7 @@ def fftn(a, s=None, axes=None, norm=None):
|
806 | 784 |
|
807 | 785 | """
|
808 | 786 |
|
809 |
| - x = _float_utils.__downcast_float128_array(a) |
| 787 | + x = __downcast_float128_array(a) |
810 | 788 | s, axes = _cook_nd_args(x, s, axes)
|
811 | 789 | fsc = _compute_fwd_scale(norm, s, x.shape)
|
812 | 790 |
|
@@ -913,7 +891,7 @@ def ifftn(a, s=None, axes=None, norm=None):
|
913 | 891 |
|
914 | 892 | """
|
915 | 893 |
|
916 |
| - x = _float_utils.__downcast_float128_array(a) |
| 894 | + x = __downcast_float128_array(a) |
917 | 895 | s, axes = _cook_nd_args(x, s, axes)
|
918 | 896 | fsc = _compute_fwd_scale(norm, s, x.shape)
|
919 | 897 |
|
@@ -1201,7 +1179,7 @@ def rfftn(a, s=None, axes=None, norm=None):
|
1201 | 1179 |
|
1202 | 1180 | """
|
1203 | 1181 |
|
1204 |
| - x = _float_utils.__downcast_float128_array(a) |
| 1182 | + x = __downcast_float128_array(a) |
1205 | 1183 | s, axes = _cook_nd_args(x, s, axes)
|
1206 | 1184 | fsc = _compute_fwd_scale(norm, s, x.shape)
|
1207 | 1185 |
|
@@ -1345,7 +1323,7 @@ def irfftn(a, s=None, axes=None, norm=None):
|
1345 | 1323 |
|
1346 | 1324 | """
|
1347 | 1325 |
|
1348 |
| - x = _float_utils.__downcast_float128_array(a) |
| 1326 | + x = __downcast_float128_array(a) |
1349 | 1327 | s, axes = _cook_nd_args(x, s, axes, invreal=True)
|
1350 | 1328 | fsc = _compute_fwd_scale(norm, s, x.shape)
|
1351 | 1329 |
|
|
0 commit comments