Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Improved documentations of `dpnp.ndarray` class and added a page with description of supported constants [#2422](https://github.com/IntelPython/dpnp/pull/2422)
* Updated `dpnp.size` to accept tuple of ints for `axes` argument [#2536](https://github.com/IntelPython/dpnp/pull/2536)
* Replaced `ci` section in `.pre-commit-config.yaml` with a new GitHub workflow with scheduled run to autoupdate the `pre-commit` configuration [#2542](https://github.com/IntelPython/dpnp/pull/2542)
* FFT module is updated to perform in-place FFT in intermediate steps of ND FFT [#2543](https://github.com/IntelPython/dpnp/pull/2543)

### Deprecated

Expand Down
107 changes: 60 additions & 47 deletions dpnp/fft/dpnp_utils_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,13 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
return dsc, out_strides


def _complex_nd_fft(
def _c2c_nd_fft(
a,
s,
norm,
out,
forward,
in_place,
c2c,
axes,
batch_fft,
*,
Expand All @@ -126,34 +125,38 @@ def _complex_nd_fft(
"""Computes complex-to-complex FFT of the input N-D array."""

len_axes = len(axes)
# OneMKL supports up to 3-dimensional FFT on GPU
# repeated axis in OneMKL FFT is not allowed
# oneMKL supports up to 3-dimensional FFT on GPU
# repeated axis in oneMKL FFT is not allowed
if len_axes > 3 or len(set(axes)) < len_axes:
axes_chunk, shape_chunk = _extract_axes_chunk(
axes, s, chunk_size=3, reversed_axes=reversed_axes
)

# We try to use in-place calculations where possible, which is
# everywhere except when the size changes after the first iteration.
size_changes = [axis for axis, n in zip(axes, s) if a.shape[axis] != n]

# cannot use out in the intermediate steps if size changes
res = None if size_changes else out

for i, (s_chunk, a_chunk) in enumerate(zip(shape_chunk, axes_chunk)):
a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk)
# if out is used in an intermediate step, it will have memory
# overlap with input and cannot be used in the final step (a new
# result array will be created for the final step), so there is no
# benefit in using out in an intermediate step
if i == len(axes_chunk) - 1:
tmp_out = out
else:
tmp_out = None
# if size_changes, out cannot be used in intermediate steps
if size_changes and i == len(axes_chunk) - 1:
res = out

a = _fft(
a,
norm=norm,
out=tmp_out,
out=res,
forward=forward,
# TODO: in-place FFT is only implemented for c2c, see SAT-7154
in_place=in_place and c2c,
c2c=c2c,
in_place=in_place,
c2c=True,
axes=a_chunk,
)

if not size_changes:
# Default output for next iteration.
res = a
return a

a = _truncate_or_pad(a, s, axes)
Expand All @@ -165,9 +168,8 @@ def _complex_nd_fft(
norm=norm,
out=out,
forward=forward,
# TODO: in-place FFT is only implemented for c2c, see SAT-7154
in_place=in_place and c2c,
c2c=c2c,
in_place=in_place,
c2c=True,
axes=axes,
batch_fft=batch_fft,
)
Expand Down Expand Up @@ -198,7 +200,7 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides):
res_usm = dpnp.get_usm_ndarray(out)
result = out
else:
# Result array that is used in OneMKL must have the exact same
# Result array that is used in oneMKL must have the exact same
# stride as input array

if c2c: # c2c FFT
Expand Down Expand Up @@ -277,9 +279,9 @@ def _copy_array(x, complex_input):
dtype = x.dtype
copy_flag = False
if numpy.min(x.strides) < 0:
# negative stride is not allowed in OneMKL FFT
# negative stride is not allowed in oneMKL FFT
# TODO: support for negative strides will be added in the future
# versions of OneMKL, see discussion in MKLD-17597
# versions of oneMKL, see discussion in MKLD-17597
copy_flag = True

if complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating):
Expand Down Expand Up @@ -388,6 +390,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):

index = 0
fft_1d = isinstance(axes, int)
if not in_place and out is not None:
# if input and output are the same array, use in-place FFT
in_place = dpnp.are_same_logical_tensors(a, out)
if batch_fft:
len_axes = 1 if fft_1d else len(axes)
local_axes = numpy.arange(-len_axes, 0)
Expand Down Expand Up @@ -627,9 +632,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
_validate_out_keyword(a, out, (n,), (axis,), c2c, c2r, r2c)
# if input array is copied, in-place FFT can be used
a, in_place = _copy_array(a, c2c or c2r)
if not in_place and out is not None:
# if input is also given for out, in-place FFT can be used
in_place = dpnp.are_same_logical_tensors(a, out)

if a.size == 0:
return dpnp.get_result_array(a, out=out, casting="same_kind")
Expand Down Expand Up @@ -695,63 +697,74 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
)

if r2c:
# a 1D real-to-complext FFT is performed on the last axis and then
size_changes = [axis for axis, n in zip(axes, s) if a.shape[axis] != n]
# cannot use out in the intermediate steps if size changes
res = None if size_changes else out

# a 1D real-to-complex FFT is performed on the last axis and then
# an N-D complex-to-complex FFT over the remaining axes
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
a = _fft(
a,
norm=norm,
# if out is used in an intermediate step, it will have memory
# overlap with input and cannot be used in the final step (a new
# result array will be created for the final step), so there is no
# benefit in using out in an intermediate step
out=None,
out=res,
forward=forward,
in_place=in_place and c2c,
c2c=c2c,
in_place=False,
c2c=False,
axes=axes[-1],
batch_fft=a.ndim != 1,
)
return _complex_nd_fft(
return _c2c_nd_fft(
a,
s=s,
s=s[:-1],
norm=norm,
out=out,
forward=forward,
in_place=in_place,
c2c=True,
axes=axes[:-1],
batch_fft=a.ndim != len_axes - 1,
)

if c2r:
# an N-D complex-to-complex FFT is performed on all axes except the
# last one then a 1D complex-to-real FFT is performed on the last axis
a = _complex_nd_fft(
a = _c2c_nd_fft(
a,
s=s,
s=s[:-1],
norm=norm,
# out has real dtype and cannot be used in intermediate steps
out=None,
forward=forward,
in_place=in_place,
c2c=True,
axes=axes[:-1],
batch_fft=a.ndim != len_axes - 1,
reversed_axes=False,
)
a = _truncate_or_pad(a, (s[-1],), (axes[-1],))
if c2r:
a = _make_array_hermitian(
a, axes[-1], dpnp.are_same_logical_tensors(a, a_orig)
)
a = _make_array_hermitian(
a, axes[-1], dpnp.are_same_logical_tensors(a, a_orig)
)
return _fft(
a, norm, out, forward, in_place and c2c, c2c, axes[-1], a.ndim != 1
a,
norm=norm,
out=out,
forward=forward,
in_place=False,
c2c=False,
axes=axes[-1],
batch_fft=a.ndim != 1,
)

# c2c
return _complex_nd_fft(
a, s, norm, out, forward, in_place, c2c, axes, a.ndim != len_axes
return _c2c_nd_fft(
a,
s=s,
norm=norm,
out=out,
forward=forward,
in_place=in_place,
axes=axes,
batch_fft=a.ndim != len_axes,
)


Expand Down
Loading