diff --git a/dpnp/fft/dpnp_iface_fft.py b/dpnp/fft/dpnp_iface_fft.py index 52753e5362e0..c9225f7c8a16 100644 --- a/dpnp/fft/dpnp_iface_fft.py +++ b/dpnp/fft/dpnp_iface_fft.py @@ -38,11 +38,7 @@ import dpnp -from .dpnp_utils_fft import ( - dpnp_fft, - dpnp_fftn, - dpnp_fillfreq, -) +from .dpnp_utils_fft import dpnp_fft, dpnp_fftn, dpnp_fillfreq, swap_direction __all__ = [ "fft", @@ -66,24 +62,6 @@ ] -_SWAP_DIRECTION_MAP = { - "backward": "forward", - None: "forward", - "ortho": "ortho", - "forward": "backward", -} - - -def _swap_direction(norm): - try: - return _SWAP_DIRECTION_MAP[norm] - except KeyError: - raise ValueError( - f'Invalid norm value {norm}; should be None, "backward", ' - '"ortho" or "forward".' - ) from None - - def fft(a, n=None, axis=-1, norm=None, out=None): """ Compute the one-dimensional discrete Fourier Transform. @@ -644,7 +622,7 @@ def hfft(a, n=None, axis=-1, norm=None, out=None): """ - new_norm = _swap_direction(norm) + new_norm = swap_direction(norm) return irfft(dpnp.conjugate(a), n=n, axis=axis, norm=new_norm, out=out) @@ -1073,7 +1051,7 @@ def ihfft(a, n=None, axis=-1, norm=None, out=None): """ - new_norm = _swap_direction(norm) + new_norm = swap_direction(norm) res = rfft(a, n=n, axis=axis, norm=new_norm, out=out) return dpnp.conjugate(res, out=out) diff --git a/dpnp/fft/dpnp_utils_fft.py b/dpnp/fft/dpnp_utils_fft.py index 1393d22255e7..25bfaa1676dd 100644 --- a/dpnp/fft/dpnp_utils_fft.py +++ b/dpnp/fft/dpnp_utils_fft.py @@ -57,11 +57,7 @@ _standardize_strides_to_nonzero, ) -__all__ = [ - "dpnp_fft", - "dpnp_fftn", - "dpnp_fillfreq", -] +__all__ = ["dpnp_fft", "dpnp_fftn", "dpnp_fillfreq", "swap_direction"] def _check_norm(norm): @@ -584,7 +580,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None): if n < 1: raise ValueError(f"Invalid number of FFT data points ({n}) specified") - _check_norm(norm) a = _truncate_or_pad(a, (n,), (axis,)) _validate_out_keyword(a, out, (n,), (axis,), c2c, c2r, r2c) # if input array is copied, in-place FFT can be used @@ -714,3 +709,17 @@ def dpnp_fillfreq(a, m, n, val): ht_lin_ev, lin_ev = ti._linspace_step(m - n, 1, a[m:].get_array(), exec_q) _manager.add_event_pair(ht_lin_ev, lin_ev) return a * val + + +def swap_direction(norm): + """Swap the direction of the FFT.""" + + _check_norm(norm) + _swap_direction_map = { + "backward": "forward", + None: "forward", + "ortho": "ortho", + "forward": "backward", + } + + return _swap_direction_map[norm] diff --git a/dpnp/tests/test_fft.py b/dpnp/tests/test_fft.py index 8a65d2857a77..7d822834b67d 100644 --- a/dpnp/tests/test_fft.py +++ b/dpnp/tests/test_fft.py @@ -14,6 +14,7 @@ get_all_dtypes, get_complex_dtypes, get_float_dtypes, + has_support_aspect16, ) @@ -926,6 +927,16 @@ def test_rfft_1D_on_2D_array_out(self, dtype, n, axis, norm, order): expected = numpy.fft.rfft(a_np, n=n, axis=axis, norm=norm) assert_dtype_allclose(result, expected, check_only_type_kind=True) + @pytest.mark.skipif(not has_support_aspect16(), reason="no fp16 support") + def test_float16(self): + a = numpy.arange(10, dtype=numpy.float16) + ia = dpnp.array(a) + + expected = numpy.fft.rfft(a) + result = dpnp.fft.rfft(ia) + # check_only_type_kind=True since Intel NumPy returns complex128 + assert_dtype_allclose(result, expected, check_only_type_kind=True) + @pytest.mark.parametrize("xp", [numpy, dpnp]) def test_rfft_error(self, xp): a = xp.ones((4, 3), dtype=xp.complex64)