Skip to content

Commit 05b820d

Browse files
authored
update a utility function in FFT module (#2393)
In this PR, a utility function in FFT module is updated to remove duplication. In addition a new test is added to make coverage for FFT module 100%.
1 parent a567221 commit 05b820d

File tree

3 files changed

+29
-31
lines changed

3 files changed

+29
-31
lines changed

dpnp/fft/dpnp_iface_fft.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@
3838

3939
import dpnp
4040

41-
from .dpnp_utils_fft import (
42-
dpnp_fft,
43-
dpnp_fftn,
44-
dpnp_fillfreq,
45-
)
41+
from .dpnp_utils_fft import dpnp_fft, dpnp_fftn, dpnp_fillfreq, swap_direction
4642

4743
__all__ = [
4844
"fft",
@@ -66,24 +62,6 @@
6662
]
6763

6864

69-
_SWAP_DIRECTION_MAP = {
70-
"backward": "forward",
71-
None: "forward",
72-
"ortho": "ortho",
73-
"forward": "backward",
74-
}
75-
76-
77-
def _swap_direction(norm):
78-
try:
79-
return _SWAP_DIRECTION_MAP[norm]
80-
except KeyError:
81-
raise ValueError(
82-
f'Invalid norm value {norm}; should be None, "backward", '
83-
'"ortho" or "forward".'
84-
) from None
85-
86-
8765
def fft(a, n=None, axis=-1, norm=None, out=None):
8866
"""
8967
Compute the one-dimensional discrete Fourier Transform.
@@ -644,7 +622,7 @@ def hfft(a, n=None, axis=-1, norm=None, out=None):
644622
645623
"""
646624

647-
new_norm = _swap_direction(norm)
625+
new_norm = swap_direction(norm)
648626
return irfft(dpnp.conjugate(a), n=n, axis=axis, norm=new_norm, out=out)
649627

650628

@@ -1073,7 +1051,7 @@ def ihfft(a, n=None, axis=-1, norm=None, out=None):
10731051
10741052
"""
10751053

1076-
new_norm = _swap_direction(norm)
1054+
new_norm = swap_direction(norm)
10771055
res = rfft(a, n=n, axis=axis, norm=new_norm, out=out)
10781056
return dpnp.conjugate(res, out=out)
10791057

dpnp/fft/dpnp_utils_fft.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,7 @@
5757
_standardize_strides_to_nonzero,
5858
)
5959

60-
__all__ = [
61-
"dpnp_fft",
62-
"dpnp_fftn",
63-
"dpnp_fillfreq",
64-
]
60+
__all__ = ["dpnp_fft", "dpnp_fftn", "dpnp_fillfreq", "swap_direction"]
6561

6662

6763
def _check_norm(norm):
@@ -584,7 +580,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
584580
if n < 1:
585581
raise ValueError(f"Invalid number of FFT data points ({n}) specified")
586582

587-
_check_norm(norm)
588583
a = _truncate_or_pad(a, (n,), (axis,))
589584
_validate_out_keyword(a, out, (n,), (axis,), c2c, c2r, r2c)
590585
# if input array is copied, in-place FFT can be used
@@ -714,3 +709,17 @@ def dpnp_fillfreq(a, m, n, val):
714709
ht_lin_ev, lin_ev = ti._linspace_step(m - n, 1, a[m:].get_array(), exec_q)
715710
_manager.add_event_pair(ht_lin_ev, lin_ev)
716711
return a * val
712+
713+
714+
def swap_direction(norm):
715+
"""Swap the direction of the FFT."""
716+
717+
_check_norm(norm)
718+
_swap_direction_map = {
719+
"backward": "forward",
720+
None: "forward",
721+
"ortho": "ortho",
722+
"forward": "backward",
723+
}
724+
725+
return _swap_direction_map[norm]

dpnp/tests/test_fft.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_all_dtypes,
1515
get_complex_dtypes,
1616
get_float_dtypes,
17+
has_support_aspect16,
1718
)
1819

1920

@@ -926,6 +927,16 @@ def test_rfft_1D_on_2D_array_out(self, dtype, n, axis, norm, order):
926927
expected = numpy.fft.rfft(a_np, n=n, axis=axis, norm=norm)
927928
assert_dtype_allclose(result, expected, check_only_type_kind=True)
928929

930+
@pytest.mark.skipif(not has_support_aspect16(), reason="no fp16 support")
931+
def test_float16(self):
932+
a = numpy.arange(10, dtype=numpy.float16)
933+
ia = dpnp.array(a)
934+
935+
expected = numpy.fft.rfft(a)
936+
result = dpnp.fft.rfft(ia)
937+
# check_only_type_kind=True since Intel NumPy returns complex128
938+
assert_dtype_allclose(result, expected, check_only_type_kind=True)
939+
929940
@pytest.mark.parametrize("xp", [numpy, dpnp])
930941
def test_rfft_error(self, xp):
931942
a = xp.ones((4, 3), dtype=xp.complex64)

0 commit comments

Comments
 (0)