Skip to content

Commit 65e33c9

Browse files
vtavanaantonwolfy
andauthored
implement dpnp.fft.fftshift and dpnp.fft.ifftshift (#1900)
* implement fftshift * remove old implementation * fix alphabetic order * add CFD tests * Apply suggestions from code review Co-authored-by: Anton <[email protected]> --------- Co-authored-by: Anton <[email protected]>
1 parent cdfadaf commit 65e33c9

File tree

5 files changed

+136
-70
lines changed

5 files changed

+136
-70
lines changed

dpnp/fft/dpnp_iface_fft.py

Lines changed: 90 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -366,45 +366,62 @@ def fftshift(x, axes=None):
366366
"""
367367
Shift the zero-frequency component to the center of the spectrum.
368368
369+
This function swaps half-spaces for all axes listed (defaults to all).
370+
Note that ``out[0]`` is the Nyquist component only if ``len(x)`` is even.
371+
369372
For full documentation refer to :obj:`numpy.fft.fftshift`.
370373
371-
Limitations
372-
-----------
373-
Parameter `x` is supported either as :class:`dpnp.ndarray`.
374-
Parameter `axes` is unsupported.
375-
Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`,
376-
`dpnp.complex128` data types are supported.
377-
Otherwise the function will be executed sequentially on CPU.
374+
Parameters
375+
----------
376+
x : {dpnp.ndarray, usm_ndarray}
377+
Input array.
378+
axes : {None, int, list or tuple of ints}, optional
379+
Axes over which to shift.
380+
Default is ``None``, which shifts all axes.
378381
379-
"""
382+
Returns
383+
-------
384+
out : dpnp.ndarray
385+
The shifted array.
380386
381-
x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False)
382-
# TODO: enable implementation
383-
# pylint: disable=condition-evals-to-constant
384-
if x_desc and 0:
385-
norm_ = Norm.backward
387+
See Also
388+
--------
389+
:obj:`dpnp.fft.ifftshift` : The inverse of :obj:`dpnp.fft.fftshift`.
386390
387-
if axes is None:
388-
axis_param = -1 # the most right dimension (default value)
389-
else:
390-
axis_param = axes
391+
Examples
392+
--------
393+
>>> import dpnp as np
394+
>>> freqs = np.fft.fftfreq(10, 0.1)
395+
>>> freqs
396+
array([ 0., 1., 2., 3., 4., -5., -4., -3., -2., -1.])
397+
>>> np.fft.fftshift(freqs)
398+
array([-5., -4., -3., -2., -1., 0., 1., 2., 3., 4.])
399+
400+
Shift the zero-frequency component only along the second axis:
401+
402+
>>> freqs = np.fft.fftfreq(9, d=1./9).reshape(3, 3)
403+
>>> freqs
404+
array([[ 0., 1., 2.],
405+
[ 3., 4., -4.],
406+
[-3., -2., -1.]])
407+
>>> np.fft.fftshift(freqs, axes=(1,))
408+
array([[ 2., 0., 1.],
409+
[-4., 3., 4.],
410+
[-1., -3., -2.]])
391411
392-
if x_desc.size < 1:
393-
pass # let fallback to handle exception
394-
else:
395-
input_boundarie = x_desc.shape[axis_param]
396-
output_boundarie = input_boundarie
412+
"""
397413

398-
return dpnp_fft_deprecated(
399-
x_desc,
400-
input_boundarie,
401-
output_boundarie,
402-
axis_param,
403-
False,
404-
norm_.value,
405-
).get_pyobj()
414+
dpnp.check_supported_arrays_type(x)
415+
if axes is None:
416+
axes = tuple(range(x.ndim))
417+
shift = [dim // 2 for dim in x.shape]
418+
elif isinstance(axes, int):
419+
shift = x.shape[axes] // 2
420+
else:
421+
x_shape = x.shape
422+
shift = [x_shape[ax] // 2 for ax in axes]
406423

407-
return call_origin(numpy.fft.fftshift, x, axes)
424+
return dpnp.roll(x, shift, axes)
408425

409426

410427
def hfft(x, n=None, axis=-1, norm=None):
@@ -620,48 +637,55 @@ def ifftshift(x, axes=None):
620637
"""
621638
Inverse shift the zero-frequency component to the center of the spectrum.
622639
623-
For full documentation refer to :obj:`numpy.fft.ifftshift`.
640+
Although identical for even-length `x`, the functions differ by one sample
641+
for odd-length `x`.
624642
625-
Limitations
626-
-----------
627-
Parameter `x` is supported either as :class:`dpnp.ndarray`.
628-
Parameter `axes` is unsupported.
629-
Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`,
630-
`dpnp.complex128` data types are supported.
631-
Otherwise the function will be executed sequentially on CPU.
632-
633-
"""
643+
For full documentation refer to :obj:`numpy.fft.ifftshift`.
634644
635-
x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False)
636-
# TODO: enable implementation
637-
# pylint: disable=condition-evals-to-constant
638-
if x_desc and 0:
639-
norm_ = Norm.backward
645+
Parameters
646+
----------
647+
x : {dpnp.ndarray, usm_ndarray}
648+
Input array.
649+
axes : {None, int, list or tuple of ints}, optional
650+
Axes over which to calculate.
651+
Defaults to ``None``, which shifts all axes.
640652
641-
if axes is None:
642-
axis_param = -1 # the most right dimension (default value)
643-
else:
644-
axis_param = axes
653+
Returns
654+
-------
655+
out : dpnp.ndarray
656+
The shifted array.
645657
646-
input_boundarie = x_desc.shape[axis_param]
658+
See Also
659+
--------
660+
:obj:`dpnp.fft.fftshift` : Shift zero-frequency component to the center
661+
of the spectrum.
647662
648-
if x_desc.size < 1:
649-
pass # let fallback to handle exception
650-
elif input_boundarie < 1:
651-
pass # let fallback to handle exception
652-
else:
653-
output_boundarie = input_boundarie
663+
Examples
664+
--------
665+
>>> import dpnp as np
666+
>>> freqs = np.fft.fftfreq(9, d=1./9).reshape(3, 3)
667+
>>> freqs
668+
array([[ 0., 1., 2.],
669+
[ 3., 4., -4.],
670+
[-3., -2., -1.]])
671+
>>> np.fft.ifftshift(np.fft.fftshift(freqs))
672+
array([[ 0., 1., 2.],
673+
[ 3., 4., -4.],
674+
[-3., -2., -1.]])
654675
655-
return dpnp_fft_deprecated(
656-
x_desc,
657-
input_boundarie,
658-
output_boundarie,
659-
axis_param,
660-
True,
661-
norm_.value,
662-
).get_pyobj()
676+
"""
663677

664-
return call_origin(numpy.fft.ifftshift, x, axes)
678+
dpnp.check_supported_arrays_type(x)
679+
if axes is None:
680+
axes = tuple(range(x.ndim))
681+
shift = [-(dim // 2) for dim in x.shape]
682+
elif isinstance(axes, int):
683+
shift = -(x.shape[axes] // 2)
684+
else:
685+
x_shape = x.shape
686+
shift = [-(x_shape[ax] // 2) for ax in axes]
687+
688+
return dpnp.roll(x, shift, axes)
665689

666690

667691
def ihfft(x, n=None, axis=-1, norm=None):

tests/test_fft.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,14 @@ def test_error(self, func):
372372

373373
# d should be an scalar
374374
assert_raises(ValueError, getattr(dpnp.fft, func), 10, (2,))
375+
376+
377+
class TestFftshift:
378+
@pytest.mark.parametrize("func", ["fftshift", "ifftshift"])
379+
@pytest.mark.parametrize("axes", [None, 1, (0, 1)])
380+
def test_fftshift(self, func, axes):
381+
x = dpnp.arange(12).reshape(3, 4)
382+
x_np = x.asnumpy()
383+
expected = getattr(dpnp.fft, func)(x, axes=axes)
384+
result = getattr(numpy.fft, func)(x_np, axes=axes)
385+
assert_dtype_allclose(expected, result)

tests/test_sycl_queue.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,27 @@ def test_fftfreq(func, device):
12761276
assert result.sycl_device == device
12771277

12781278

1279+
@pytest.mark.parametrize("func", ["fftshift", "ifftshift"])
1280+
@pytest.mark.parametrize(
1281+
"device",
1282+
valid_devices,
1283+
ids=[device.filter_string for device in valid_devices],
1284+
)
1285+
def test_fftshift(func, device):
1286+
dpnp_data = dpnp.fft.fftfreq(10, 0.5, device=device)
1287+
data = dpnp_data.asnumpy()
1288+
1289+
expected = getattr(numpy.fft, func)(data)
1290+
result = getattr(dpnp.fft, func)(dpnp_data)
1291+
1292+
assert_dtype_allclose(result, expected)
1293+
1294+
expected_queue = dpnp_data.get_array().sycl_queue
1295+
result_queue = result.get_array().sycl_queue
1296+
1297+
assert_sycl_queue_equal(result_queue, expected_queue)
1298+
1299+
12791300
@pytest.mark.parametrize(
12801301
"data, is_empty",
12811302
[

tests/test_usm_type.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,17 @@ def test_eigenvalue(func, shape, usm_type):
933933
assert a.usm_type == dp_val.usm_type
934934

935935

936+
@pytest.mark.parametrize("func", ["fft", "ifft"])
937+
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
938+
def test_fft(func, usm_type):
939+
940+
dpnp_data = dp.arange(100, usm_type=usm_type, dtype=dp.complex64)
941+
result = getattr(dp.fft, func)(dpnp_data)
942+
943+
assert dpnp_data.usm_type == usm_type
944+
assert result.usm_type == usm_type
945+
946+
936947
@pytest.mark.parametrize("func", ["fftfreq", "rfftfreq"])
937948
@pytest.mark.parametrize("usm_type", list_of_usm_types + [None])
938949
def test_fftfreq(func, usm_type):
@@ -947,10 +958,10 @@ def test_fftfreq(func, usm_type):
947958
assert result.usm_type == usm_type
948959

949960

950-
@pytest.mark.parametrize("func", ["fft", "ifft"])
961+
@pytest.mark.parametrize("func", ["fftshift", "ifftshift"])
951962
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
952-
def test_fft(func, usm_type):
953-
dpnp_data = dp.arange(100, usm_type=usm_type, dtype=dp.complex64)
963+
def test_fftshift(func, usm_type):
964+
dpnp_data = dp.fft.fftfreq(10, 0.5, usm_type=usm_type)
954965
result = getattr(dp.fft, func)(dpnp_data)
955966

956967
assert dpnp_data.usm_type == usm_type

tests/third_party/cupy/fft_tests/test_fft.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,6 @@ def test_rfftfreq(self, xp):
380380
{"shape": (10, 10), "axes": 0},
381381
{"shape": (10, 10), "axes": (0, 1)},
382382
)
383-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
384383
class TestFftshift:
385384
@testing.for_all_dtypes()
386385
@testing.numpy_cupy_allclose(

0 commit comments

Comments
 (0)