Skip to content

Commit 330deae

Browse files
committed
Tests for consistency of HEALPix FFT and IFFT implementations
1 parent 11592c0 commit 330deae

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

tests/test_healpix_ffts.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import numpy as np
2+
import pytest
3+
from jax import config
4+
from s2fft.utils.healpix_ffts import (
5+
healpix_fft_jax,
6+
healpix_fft_numpy,
7+
healpix_ifft_jax,
8+
healpix_ifft_numpy,
9+
)
10+
11+
12+
config.update("jax_enable_x64", True)
13+
14+
15+
@pytest.mark.parametrize("L", (32, 64))
16+
@pytest.mark.parametrize("nside", (4, 8, 16))
17+
@pytest.mark.parametrize("reality", (True, False))
18+
def test_healpix_fft_jax_numpy_consistency(rng, L, nside, reality):
19+
f = rng.standard_normal(size=12 * nside**2)
20+
assert np.allclose(
21+
healpix_fft_numpy(f, L, nside, reality), healpix_fft_jax(f, L, nside, reality)
22+
)
23+
24+
25+
@pytest.mark.parametrize("L", (32, 64))
26+
@pytest.mark.parametrize("nside", (4, 8, 16))
27+
@pytest.mark.parametrize("reality", (True, False))
28+
def test_healpix_fft_ifft_numpy_consistency(rng, L, nside, reality):
29+
f = rng.standard_normal(size=12 * nside**2)
30+
assert np.allclose(
31+
f,
32+
healpix_ifft_numpy(healpix_fft_numpy(f, L, nside, reality), L, nside, reality),
33+
)
34+
35+
36+
@pytest.mark.parametrize("L", (32, 64))
37+
@pytest.mark.parametrize("nside", (4, 8, 16))
38+
@pytest.mark.parametrize("reality", (True, False))
39+
def test_healpix_fft_ifft_jax_consistency(rng, L, nside, reality):
40+
f = rng.standard_normal(size=12 * nside**2)
41+
assert np.allclose(
42+
f, healpix_ifft_jax(healpix_fft_jax(f, L, nside, reality), L, nside, reality)
43+
)
44+
45+
46+
@pytest.mark.parametrize("L", (32, 64))
47+
@pytest.mark.parametrize("nside", (4, 8, 16))
48+
@pytest.mark.parametrize("reality", (True, False))
49+
def test_healpix_ifft_jax_numpy_consistency(rng, L, nside, reality):
50+
ftm = healpix_fft_numpy(
51+
rng.standard_normal(size=12 * nside**2), L, nside, reality
52+
)
53+
assert np.allclose(
54+
healpix_ifft_numpy(ftm, L, nside, reality),
55+
healpix_ifft_jax(ftm, L, nside, reality),
56+
)

0 commit comments

Comments
 (0)