Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 3 additions & 25 deletions dpnp/fft/dpnp_iface_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@

import dpnp

from .dpnp_utils_fft import (
dpnp_fft,
dpnp_fftn,
dpnp_fillfreq,
)
from .dpnp_utils_fft import dpnp_fft, dpnp_fftn, dpnp_fillfreq, swap_direction

__all__ = [
"fft",
Expand All @@ -66,24 +62,6 @@
]


_SWAP_DIRECTION_MAP = {
"backward": "forward",
None: "forward",
"ortho": "ortho",
"forward": "backward",
}


def _swap_direction(norm):
try:
return _SWAP_DIRECTION_MAP[norm]
except KeyError:
raise ValueError(
f'Invalid norm value {norm}; should be None, "backward", '
'"ortho" or "forward".'
) from None


def fft(a, n=None, axis=-1, norm=None, out=None):
"""
Compute the one-dimensional discrete Fourier Transform.
Expand Down Expand Up @@ -644,7 +622,7 @@ def hfft(a, n=None, axis=-1, norm=None, out=None):

"""

new_norm = _swap_direction(norm)
new_norm = swap_direction(norm)
return irfft(dpnp.conjugate(a), n=n, axis=axis, norm=new_norm, out=out)


Expand Down Expand Up @@ -1073,7 +1051,7 @@ def ihfft(a, n=None, axis=-1, norm=None, out=None):

"""

new_norm = _swap_direction(norm)
new_norm = swap_direction(norm)
res = rfft(a, n=n, axis=axis, norm=new_norm, out=out)
return dpnp.conjugate(res, out=out)

Expand Down
21 changes: 15 additions & 6 deletions dpnp/fft/dpnp_utils_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,7 @@
_standardize_strides_to_nonzero,
)

__all__ = [
"dpnp_fft",
"dpnp_fftn",
"dpnp_fillfreq",
]
__all__ = ["dpnp_fft", "dpnp_fftn", "dpnp_fillfreq", "swap_direction"]


def _check_norm(norm):
Expand Down Expand Up @@ -584,7 +580,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
if n < 1:
raise ValueError(f"Invalid number of FFT data points ({n}) specified")

_check_norm(norm)
a = _truncate_or_pad(a, (n,), (axis,))
_validate_out_keyword(a, out, (n,), (axis,), c2c, c2r, r2c)
# if input array is copied, in-place FFT can be used
Expand Down Expand Up @@ -714,3 +709,17 @@ def dpnp_fillfreq(a, m, n, val):
ht_lin_ev, lin_ev = ti._linspace_step(m - n, 1, a[m:].get_array(), exec_q)
_manager.add_event_pair(ht_lin_ev, lin_ev)
return a * val


def swap_direction(norm):
"""Swap the direction of the FFT."""

_check_norm(norm)
_swap_direction_map = {
"backward": "forward",
None: "forward",
"ortho": "ortho",
"forward": "backward",
}

return _swap_direction_map[norm]
11 changes: 11 additions & 0 deletions dpnp/tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_all_dtypes,
get_complex_dtypes,
get_float_dtypes,
has_support_aspect16,
)


Expand Down Expand Up @@ -926,6 +927,16 @@ def test_rfft_1D_on_2D_array_out(self, dtype, n, axis, norm, order):
expected = numpy.fft.rfft(a_np, n=n, axis=axis, norm=norm)
assert_dtype_allclose(result, expected, check_only_type_kind=True)

@pytest.mark.skipif(not has_support_aspect16(), reason="no fp16 support")
def test_float16(self):
a = numpy.ones(10, dtype=numpy.float16)
ia = dpnp.array(a)

result = numpy.fft.rfft(a)
expected = dpnp.fft.rfft(ia)
# check_only_type_kind=True since Intel NumPy returns complex128
assert_dtype_allclose(expected, result, check_only_type_kind=True)

@pytest.mark.parametrize("xp", [numpy, dpnp])
def test_rfft_error(self, xp):
a = xp.ones((4, 3), dtype=xp.complex64)
Expand Down
1 change: 0 additions & 1 deletion dpnp/tests/test_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
get_all_dtypes,
get_float_complex_dtypes,
get_float_dtypes,
has_support_aspect16,
has_support_aspect64,
)

Expand Down
Loading