Skip to content

Commit 4fe1afc

Browse files
committed
update test to avoid inplace array updates
1 parent ac9eee8 commit 4fe1afc

File tree

1 file changed

+27
-13
lines changed

1 file changed

+27
-13
lines changed

tests/test_healpix_ffts.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
2+
import healpy as hp
23
import pytest
34
from jax import config
5+
from s2fft.sampling import s2_samples as samples
46
from s2fft.utils.healpix_ffts import (
57
healpix_fft_jax,
68
healpix_fft_numpy,
@@ -12,24 +14,36 @@
1214
config.update("jax_enable_x64", True)
1315

1416

15-
@pytest.mark.parametrize("L", (32, 64))
16-
@pytest.mark.parametrize("nside", (4, 8, 16))
17-
@pytest.mark.parametrize("reality", (True, False))
18-
def test_healpix_fft_jax_numpy_consistency(rng, L, nside, reality):
19-
f = rng.standard_normal(size=12 * nside**2)
17+
nside_to_test = [4, 5]
18+
reality_to_test = [False, True]
19+
20+
21+
@pytest.mark.parametrize("nside", nside_to_test)
22+
@pytest.mark.parametrize("reality", reality_to_test)
23+
def test_healpix_fft_jax_numpy_consistency(flm_generator, nside, reality):
24+
L = 2 * nside
25+
# Generate a random bandlimited signal
26+
flm = flm_generator(L=L, reality=reality)
27+
flm_hp = samples.flm_2d_to_hp(flm, L)
28+
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
29+
# Test consistency
2030
assert np.allclose(
2131
healpix_fft_numpy(f, L, nside, reality), healpix_fft_jax(f, L, nside, reality)
2232
)
2333

2434

25-
@pytest.mark.parametrize("L", (32, 64))
26-
@pytest.mark.parametrize("nside", (4, 8, 16))
27-
@pytest.mark.parametrize("reality", (True, False))
28-
def test_healpix_ifft_jax_numpy_consistency(rng, L, nside, reality):
29-
ftm = healpix_fft_numpy(
30-
rng.standard_normal(size=12 * nside**2), L, nside, reality
31-
)
35+
@pytest.mark.parametrize("nside", nside_to_test)
36+
@pytest.mark.parametrize("reality", reality_to_test)
37+
def test_healpix_ifft_jax_numpy_consistency(flm_generator, nside, reality):
38+
L = 2 * nside
39+
# Generate a random bandlimited signal
40+
flm = flm_generator(L=L, reality=reality)
41+
flm_hp = samples.flm_2d_to_hp(flm, L)
42+
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
43+
ftm = healpix_fft_numpy(f, L, nside, reality)
44+
ftm_copy = np.copy(ftm)
45+
# Test consistency
3246
assert np.allclose(
3347
healpix_ifft_numpy(ftm, L, nside, reality),
34-
healpix_ifft_jax(ftm, L, nside, reality),
48+
healpix_ifft_jax(ftm_copy, L, nside, reality),
3549
)

0 commit comments

Comments
 (0)