From 3c28fb67740f62a67341437948edbfa9f68988f4 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 11 Dec 2024 13:15:42 +0000 Subject: [PATCH 1/3] Add `fft` --- cubed/__init__.py | 4 ++-- cubed/array_api/fft.py | 30 ++++++++++++++++++++++++++++++ cubed/tests/test_fft.py | 25 +++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 cubed/array_api/fft.py create mode 100644 cubed/tests/test_fft.py diff --git a/cubed/__init__.py b/cubed/__init__.py index 9134619d0..a87e912fb 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -337,6 +337,6 @@ # extensions -from .array_api import linalg +from .array_api import fft, linalg -__all__ += ["linalg"] +__all__ += ["fft", "linalg"] diff --git a/cubed/array_api/fft.py b/cubed/array_api/fft.py new file mode 100644 index 000000000..6429774ce --- /dev/null +++ b/cubed/array_api/fft.py @@ -0,0 +1,30 @@ +from cubed.backend_array_api import namespace as nxp +from cubed.core.ops import map_blocks + + +def fft(x, /, *, n=None, axis=-1, norm="backward"): + if x.numblocks[axis] > 1: + raise ValueError( + "FFT is only supported along axes with a single chunk. " + # TODO: give details about what was tried and mention rechunking (see qr message) + ) + + if n is None: + chunks = x.chunks + else: + chunks = list(x.chunks) + chunks[axis] = (n,) + + return map_blocks( + _fft, + x, + dtype=nxp.complex128, + chunks=chunks, + n=n, + axis=axis, + norm=norm, + ) + + +def _fft(a, n=None, axis=None, norm=None): + return nxp.fft.fft(a, n=n, axis=axis, norm=norm) diff --git a/cubed/tests/test_fft.py b/cubed/tests/test_fft.py new file mode 100644 index 000000000..c2d435702 --- /dev/null +++ b/cubed/tests/test_fft.py @@ -0,0 +1,25 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +import cubed +import cubed.array_api as xp +from cubed.backend_array_api import namespace as nxp + + +@pytest.mark.parametrize("n", [None, 5, 13]) +def test_fft(n): + an = np.arange(100).reshape(10, 10) + bn = nxp.fft.fft(an, n=n) + a = cubed.from_array(an, chunks=(1, 10)) + b = xp.fft.fft(a, n=n) + + assert_array_equal(b.compute(), bn) + + +def test_fft_chunked_axis_fails(): + an = np.arange(100).reshape(10, 10) + a = cubed.from_array(an, chunks=(1, 10)) + + with pytest.raises(ValueError): + xp.fft.fft(a, axis=0) From 37fee5b68921014481bd4c5b4e19d485fda23bba Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 11 Dec 2024 13:55:41 +0000 Subject: [PATCH 2/3] ifft --- cubed/array_api/fft.py | 15 ++++++++++++--- cubed/tests/test_fft.py | 15 +++++++++------ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/cubed/array_api/fft.py b/cubed/array_api/fft.py index 6429774ce..d690e4221 100644 --- a/cubed/array_api/fft.py +++ b/cubed/array_api/fft.py @@ -3,9 +3,17 @@ def fft(x, /, *, n=None, axis=-1, norm="backward"): + return fft_1d(nxp.fft.fft, x, n=n, axis=axis, norm=norm) + + +def ifft(x, /, *, n=None, axis=-1, norm="backward"): + return fft_1d(nxp.fft.ifft, x, n=n, axis=axis, norm=norm) + + +def fft_1d(fft_func, x, /, *, n=None, axis=-1, norm="backward"): if x.numblocks[axis] > 1: raise ValueError( - "FFT is only supported along axes with a single chunk. " + "FFT can only be applied along axes with a single chunk. " # TODO: give details about what was tried and mention rechunking (see qr message) ) @@ -20,11 +28,12 @@ def fft(x, /, *, n=None, axis=-1, norm="backward"): x, dtype=nxp.complex128, chunks=chunks, + fft_func=fft_func, n=n, axis=axis, norm=norm, ) -def _fft(a, n=None, axis=None, norm=None): - return nxp.fft.fft(a, n=n, axis=axis, norm=norm) +def _fft(a, fft_func=None, n=None, axis=None, norm=None): + return fft_func(a, n=n, axis=axis, norm=norm) diff --git a/cubed/tests/test_fft.py b/cubed/tests/test_fft.py index c2d435702..7debe33eb 100644 --- a/cubed/tests/test_fft.py +++ b/cubed/tests/test_fft.py @@ -1,4 +1,3 @@ -import numpy as np import pytest from numpy.testing import assert_array_equal @@ -7,18 +6,22 @@ from cubed.backend_array_api import namespace as nxp +@pytest.mark.parametrize("funcname", ["fft", "ifft"]) @pytest.mark.parametrize("n", [None, 5, 13]) -def test_fft(n): - an = np.arange(100).reshape(10, 10) - bn = nxp.fft.fft(an, n=n) +def test_fft(funcname, n): + nxp_fft = getattr(nxp.fft, funcname) + cb_fft = getattr(xp.fft, funcname) + + an = nxp.arange(100).reshape(10, 10) + bn = nxp_fft(an, n=n) a = cubed.from_array(an, chunks=(1, 10)) - b = xp.fft.fft(a, n=n) + b = cb_fft(a, n=n) assert_array_equal(b.compute(), bn) def test_fft_chunked_axis_fails(): - an = np.arange(100).reshape(10, 10) + an = nxp.arange(100).reshape(10, 10) a = cubed.from_array(an, chunks=(1, 10)) with pytest.raises(ValueError): From fd7fdf3f165d54575317357248b130587162c8be Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 20 Jan 2025 09:50:50 +0000 Subject: [PATCH 3/3] Use assert_allclose --- cubed/tests/test_fft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cubed/tests/test_fft.py b/cubed/tests/test_fft.py index 7debe33eb..3b408538d 100644 --- a/cubed/tests/test_fft.py +++ b/cubed/tests/test_fft.py @@ -1,5 +1,5 @@ import pytest -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose import cubed import cubed.array_api as xp @@ -17,7 +17,7 @@ def test_fft(funcname, n): a = cubed.from_array(an, chunks=(1, 10)) b = cb_fft(a, n=n) - assert_array_equal(b.compute(), bn) + assert_allclose(b.compute(), bn) def test_fft_chunked_axis_fails():