diff --git a/s2fft/transforms/otf_recursions.py b/s2fft/transforms/otf_recursions.py index 50718eb2..bb036e46 100644 --- a/s2fft/transforms/otf_recursions.py +++ b/s2fft/transforms/otf_recursions.py @@ -80,7 +80,14 @@ def inverse_latitudinal_step( half_slices = [el + mm + 1, el - mm + 1] if precomps is None: - precomps = generate_precomputes(L, -mm, sampling, nside, L_lower) + precomps = generate_precomputes( + L=L, + spin=-mm, + sampling=sampling, + nside=nside, + forward=False, + L_lower=L_lower, + ) lrenorm, vsign, cpi, cp2, indices = precomps for i in range(2): diff --git a/tests/test_spherical_transform.py b/tests/test_spherical_transform.py index a4129fba..5a57bc2d 100644 --- a/tests/test_spherical_transform.py +++ b/tests/test_spherical_transform.py @@ -27,6 +27,7 @@ @pytest.mark.parametrize("method", method_to_test) @pytest.mark.parametrize("reality", reality_to_test) @pytest.mark.parametrize("spmd", multiple_gpus) +@pytest.mark.parametrize("use_generate_precomputes", [True, False]) @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_transform_inverse( flm_generator, @@ -37,6 +38,7 @@ def test_transform_inverse( method: str, reality: bool, spmd: bool, + use_generate_precomputes: bool, ): if reality and spin != 0: pytest.skip("Reality only valid for scalar fields (spin=0).") @@ -52,7 +54,10 @@ def test_transform_inverse( Reality=reality, ) - precomps = generate_precomputes(L, spin, sampling, L_lower=L_lower) + if use_generate_precomputes: + precomps = generate_precomputes(L, spin, sampling, L_lower=L_lower) + else: + precomps = None f = spherical.inverse( flm, L, @@ -106,6 +111,7 @@ def test_transform_inverse_healpix( @pytest.mark.parametrize("method", method_to_test) @pytest.mark.parametrize("reality", reality_to_test) @pytest.mark.parametrize("spmd", multiple_gpus) +@pytest.mark.parametrize("use_generate_precomputes", [True, False]) @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_transform_forward( flm_generator, @@ -116,6 +122,7 @@ def test_transform_forward( method: str, reality: bool, spmd: bool, + use_generate_precomputes: bool, ): if reality and spin != 0: pytest.skip("Reality only valid for scalar fields (spin=0).") @@ -131,8 +138,10 @@ def test_transform_forward( Spin=spin, Reality=reality, ) - - precomps = generate_precomputes(L, spin, sampling, None, True, L_lower) + if use_generate_precomputes: + precomps = generate_precomputes(L, spin, sampling, None, True, L_lower) + else: + precomps = None flm_check = spherical.forward( f, L,