diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d87aadd1824..f3e49a10e25a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Improved documentations of `dpnp.ndarray` class and added a page with description of supported constants [#2422](https://github.com/IntelPython/dpnp/pull/2422) * Updated `dpnp.size` to accept tuple of ints for `axes` argument [#2536](https://github.com/IntelPython/dpnp/pull/2536) * Replaced `ci` section in `.pre-commit-config.yaml` with a new GitHub workflow with scheduled run to autoupdate the `pre-commit` configuration [#2542](https://github.com/IntelPython/dpnp/pull/2542) +* FFT module is updated to perform in-place FFT in intermediate steps of ND FFT [#2543](https://github.com/IntelPython/dpnp/pull/2543) ### Deprecated diff --git a/dpnp/fft/dpnp_utils_fft.py b/dpnp/fft/dpnp_utils_fft.py index fa6e49249fd4..2dcefaaeb757 100644 --- a/dpnp/fft/dpnp_utils_fft.py +++ b/dpnp/fft/dpnp_utils_fft.py @@ -110,14 +110,13 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft): return dsc, out_strides -def _complex_nd_fft( +def _c2c_nd_fft( a, s, norm, out, forward, in_place, - c2c, axes, batch_fft, *, @@ -126,34 +125,38 @@ def _complex_nd_fft( """Computes complex-to-complex FFT of the input N-D array.""" len_axes = len(axes) - # OneMKL supports up to 3-dimensional FFT on GPU - # repeated axis in OneMKL FFT is not allowed + # oneMKL supports up to 3-dimensional FFT on GPU + # repeated axis in oneMKL FFT is not allowed if len_axes > 3 or len(set(axes)) < len_axes: axes_chunk, shape_chunk = _extract_axes_chunk( axes, s, chunk_size=3, reversed_axes=reversed_axes ) + + # We try to use in-place calculations where possible, which is + # everywhere except when the size changes after the first iteration. + size_changes = [axis for axis, n in zip(axes, s) if a.shape[axis] != n] + + # cannot use out in the intermediate steps if size changes + res = None if size_changes else out + for i, (s_chunk, a_chunk) in enumerate(zip(shape_chunk, axes_chunk)): a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk) - # if out is used in an intermediate step, it will have memory - # overlap with input and cannot be used in the final step (a new - # result array will be created for the final step), so there is no - # benefit in using out in an intermediate step - if i == len(axes_chunk) - 1: - tmp_out = out - else: - tmp_out = None + # if size_changes, out cannot be used in intermediate steps + if size_changes and i == len(axes_chunk) - 1: + res = out a = _fft( a, norm=norm, - out=tmp_out, + out=res, forward=forward, - # TODO: in-place FFT is only implemented for c2c, see SAT-7154 - in_place=in_place and c2c, - c2c=c2c, + in_place=in_place, + c2c=True, axes=a_chunk, ) - + if not size_changes: + # Default output for next iteration. + res = a return a a = _truncate_or_pad(a, s, axes) @@ -165,9 +168,8 @@ def _complex_nd_fft( norm=norm, out=out, forward=forward, - # TODO: in-place FFT is only implemented for c2c, see SAT-7154 - in_place=in_place and c2c, - c2c=c2c, + in_place=in_place, + c2c=True, axes=axes, batch_fft=batch_fft, ) @@ -198,7 +200,7 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides): res_usm = dpnp.get_usm_ndarray(out) result = out else: - # Result array that is used in OneMKL must have the exact same + # Result array that is used in oneMKL must have the exact same # stride as input array if c2c: # c2c FFT @@ -277,9 +279,9 @@ def _copy_array(x, complex_input): dtype = x.dtype copy_flag = False if numpy.min(x.strides) < 0: - # negative stride is not allowed in OneMKL FFT + # negative stride is not allowed in oneMKL FFT # TODO: support for negative strides will be added in the future - # versions of OneMKL, see discussion in MKLD-17597 + # versions of oneMKL, see discussion in MKLD-17597 copy_flag = True if complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating): @@ -388,6 +390,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True): index = 0 fft_1d = isinstance(axes, int) + if not in_place and out is not None: + # if input and output are the same array, use in-place FFT + in_place = dpnp.are_same_logical_tensors(a, out) if batch_fft: len_axes = 1 if fft_1d else len(axes) local_axes = numpy.arange(-len_axes, 0) @@ -627,9 +632,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None): _validate_out_keyword(a, out, (n,), (axis,), c2c, c2r, r2c) # if input array is copied, in-place FFT can be used a, in_place = _copy_array(a, c2c or c2r) - if not in_place and out is not None: - # if input is also given for out, in-place FFT can be used - in_place = dpnp.are_same_logical_tensors(a, out) if a.size == 0: return dpnp.get_result_array(a, out=out, casting="same_kind") @@ -695,31 +697,30 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None): ) if r2c: - # a 1D real-to-complext FFT is performed on the last axis and then + size_changes = [axis for axis, n in zip(axes, s) if a.shape[axis] != n] + # cannot use out in the intermediate steps if size changes + res = None if size_changes else out + + # a 1D real-to-complex FFT is performed on the last axis and then # an N-D complex-to-complex FFT over the remaining axes a = _truncate_or_pad(a, (s[-1],), (axes[-1],)) a = _fft( a, norm=norm, - # if out is used in an intermediate step, it will have memory - # overlap with input and cannot be used in the final step (a new - # result array will be created for the final step), so there is no - # benefit in using out in an intermediate step - out=None, + out=res, forward=forward, - in_place=in_place and c2c, - c2c=c2c, + in_place=False, + c2c=False, axes=axes[-1], batch_fft=a.ndim != 1, ) - return _complex_nd_fft( + return _c2c_nd_fft( a, - s=s, + s=s[:-1], norm=norm, out=out, forward=forward, in_place=in_place, - c2c=True, axes=axes[:-1], batch_fft=a.ndim != len_axes - 1, ) @@ -727,31 +728,43 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None): if c2r: # an N-D complex-to-complex FFT is performed on all axes except the # last one then a 1D complex-to-real FFT is performed on the last axis - a = _complex_nd_fft( + a = _c2c_nd_fft( a, - s=s, + s=s[:-1], norm=norm, # out has real dtype and cannot be used in intermediate steps out=None, forward=forward, in_place=in_place, - c2c=True, axes=axes[:-1], batch_fft=a.ndim != len_axes - 1, reversed_axes=False, ) a = _truncate_or_pad(a, (s[-1],), (axes[-1],)) - if c2r: - a = _make_array_hermitian( - a, axes[-1], dpnp.are_same_logical_tensors(a, a_orig) - ) + a = _make_array_hermitian( + a, axes[-1], dpnp.are_same_logical_tensors(a, a_orig) + ) return _fft( - a, norm, out, forward, in_place and c2c, c2c, axes[-1], a.ndim != 1 + a, + norm=norm, + out=out, + forward=forward, + in_place=False, + c2c=False, + axes=axes[-1], + batch_fft=a.ndim != 1, ) # c2c - return _complex_nd_fft( - a, s, norm, out, forward, in_place, c2c, axes, a.ndim != len_axes + return _c2c_nd_fft( + a, + s=s, + norm=norm, + out=out, + forward=forward, + in_place=in_place, + axes=axes, + batch_fft=a.ndim != len_axes, )