Skip to content

Commit a4a76ca

Browse files
author
Vahid Tavanashad
committed
update norm FFT
1 parent eb8ebf8 commit a4a76ca

File tree

4 files changed

+27
-32
lines changed

4 files changed

+27
-32
lines changed

dpnp/fft/dpnp_iface_fft.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@
3838

3939
import dpnp
4040

41-
from .dpnp_utils_fft import (
42-
dpnp_fft,
43-
dpnp_fftn,
44-
dpnp_fillfreq,
45-
)
41+
from .dpnp_utils_fft import dpnp_fft, dpnp_fftn, dpnp_fillfreq, swap_direction
4642

4743
__all__ = [
4844
"fft",
@@ -66,24 +62,6 @@
6662
]
6763

6864

69-
_SWAP_DIRECTION_MAP = {
70-
"backward": "forward",
71-
None: "forward",
72-
"ortho": "ortho",
73-
"forward": "backward",
74-
}
75-
76-
77-
def _swap_direction(norm):
78-
try:
79-
return _SWAP_DIRECTION_MAP[norm]
80-
except KeyError:
81-
raise ValueError(
82-
f'Invalid norm value {norm}; should be None, "backward", '
83-
'"ortho" or "forward".'
84-
) from None
85-
86-
8765
def fft(a, n=None, axis=-1, norm=None, out=None):
8866
"""
8967
Compute the one-dimensional discrete Fourier Transform.
@@ -644,7 +622,7 @@ def hfft(a, n=None, axis=-1, norm=None, out=None):
644622
645623
"""
646624

647-
new_norm = _swap_direction(norm)
625+
new_norm = swap_direction(norm)
648626
return irfft(dpnp.conjugate(a), n=n, axis=axis, norm=new_norm, out=out)
649627

650628

@@ -1073,7 +1051,7 @@ def ihfft(a, n=None, axis=-1, norm=None, out=None):
10731051
10741052
"""
10751053

1076-
new_norm = _swap_direction(norm)
1054+
new_norm = swap_direction(norm)
10771055
res = rfft(a, n=n, axis=axis, norm=new_norm, out=out)
10781056
return dpnp.conjugate(res, out=out)
10791057

dpnp/fft/dpnp_utils_fft.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,7 @@
5757
_standardize_strides_to_nonzero,
5858
)
5959

60-
__all__ = [
61-
"dpnp_fft",
62-
"dpnp_fftn",
63-
"dpnp_fillfreq",
64-
]
60+
__all__ = ["dpnp_fft", "dpnp_fftn", "dpnp_fillfreq", "swap_direction"]
6561

6662

6763
def _check_norm(norm):
@@ -584,7 +580,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
584580
if n < 1:
585581
raise ValueError(f"Invalid number of FFT data points ({n}) specified")
586582

587-
_check_norm(norm)
588583
a = _truncate_or_pad(a, (n,), (axis,))
589584
_validate_out_keyword(a, out, (n,), (axis,), c2c, c2r, r2c)
590585
# if input array is copied, in-place FFT can be used
@@ -714,3 +709,17 @@ def dpnp_fillfreq(a, m, n, val):
714709
ht_lin_ev, lin_ev = ti._linspace_step(m - n, 1, a[m:].get_array(), exec_q)
715710
_manager.add_event_pair(ht_lin_ev, lin_ev)
716711
return a * val
712+
713+
714+
def swap_direction(norm):
715+
"""Swap the direction of the FFT."""
716+
717+
_check_norm(norm)
718+
_swap_direction_map = {
719+
"backward": "forward",
720+
None: "forward",
721+
"ortho": "ortho",
722+
"forward": "backward",
723+
}
724+
725+
return _swap_direction_map[norm]

dpnp/tests/test_fft.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,15 @@ def test_rfft_1D_on_2D_array_out(self, dtype, n, axis, norm, order):
926926
expected = numpy.fft.rfft(a_np, n=n, axis=axis, norm=norm)
927927
assert_dtype_allclose(result, expected, check_only_type_kind=True)
928928

929+
@pytest.mark.skipif(not has_support_aspect16(), reason="no fp16 support")
930+
def test_float16(self):
931+
a = numpy.ones(10, dtype=numpy.float16)
932+
ia = dpnp.array(a)
933+
934+
result = numpy.fft.rfft(a)
935+
expected = dpnp.fft.rfft(ia)
936+
assert_dtype_allclose(expected, result)
937+
929938
@pytest.mark.parametrize("xp", [numpy, dpnp])
930939
def test_rfft_error(self, xp):
931940
a = xp.ones((4, 3), dtype=xp.complex64)

dpnp/tests/test_umath.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
get_all_dtypes,
1919
get_float_complex_dtypes,
2020
get_float_dtypes,
21-
has_support_aspect16,
2221
has_support_aspect64,
2322
)
2423

0 commit comments

Comments
 (0)