|
| 1 | +from jax.config import config |
| 2 | + |
| 3 | +config.update("jax_enable_x64", True) |
1 | 4 | import pytest |
| 5 | +import pyssht as ssht |
2 | 6 | import numpy as np |
3 | 7 | from s2fft.sampling import s2_samples as samples |
| 8 | +from s2fft.utils.augmentation import rotate_flms, generate_rotate_dls |
| 9 | + |
| 10 | +L_to_test = [6, 8, 10] |
| 11 | +angles_to_test = [np.pi / 2, np.pi / 6] |
4 | 12 |
|
5 | 13 |
|
6 | 14 | def test_flm_reindexing_functions(flm_generator): |
@@ -47,3 +55,44 @@ def test_flm_reindexing_exceptions(flm_generator): |
47 | 55 |
|
48 | 56 | with pytest.raises(ValueError) as e: |
49 | 57 | samples.flm_1d_to_2d(flm_3d, L) |
| 58 | + |
| 59 | + |
| 60 | +@pytest.mark.parametrize("L", L_to_test) |
| 61 | +@pytest.mark.parametrize("alpha", angles_to_test) |
| 62 | +@pytest.mark.parametrize("beta", angles_to_test) |
| 63 | +@pytest.mark.parametrize("gamma", angles_to_test) |
| 64 | +@pytest.mark.filterwarnings("ignore::RuntimeWarning") |
| 65 | +@pytest.mark.filterwarnings("ignore::FutureWarning") |
| 66 | +def test_rotate_flms(flm_generator, L: int, alpha: float, beta: float, gamma: float): |
| 67 | + flm = flm_generator(L=L) |
| 68 | + rot = (alpha, beta, gamma) |
| 69 | + flm_1d = samples.flm_2d_to_1d(flm, L) |
| 70 | + |
| 71 | + flm_rot_ssht = samples.flm_1d_to_2d( |
| 72 | + ssht.rotate_flms(flm_1d, alpha, beta, gamma, L), L |
| 73 | + ) |
| 74 | + flm_rot_s2fft = rotate_flms(flm, L, rot) |
| 75 | + |
| 76 | + np.testing.assert_allclose(flm_rot_ssht, flm_rot_s2fft, atol=1e-14) |
| 77 | + |
| 78 | + |
| 79 | +@pytest.mark.parametrize("L", L_to_test) |
| 80 | +@pytest.mark.parametrize("alpha", angles_to_test) |
| 81 | +@pytest.mark.parametrize("beta", angles_to_test) |
| 82 | +@pytest.mark.parametrize("gamma", angles_to_test) |
| 83 | +@pytest.mark.filterwarnings("ignore::RuntimeWarning") |
| 84 | +@pytest.mark.filterwarnings("ignore::FutureWarning") |
| 85 | +def test_rotate_flms_precompute_dls( |
| 86 | + flm_generator, L: int, alpha: float, beta: float, gamma: float |
| 87 | +): |
| 88 | + dl = generate_rotate_dls(L, beta) |
| 89 | + flm = flm_generator(L=L) |
| 90 | + rot = (alpha, beta, gamma) |
| 91 | + flm_1d = samples.flm_2d_to_1d(flm, L) |
| 92 | + |
| 93 | + flm_rot_ssht = samples.flm_1d_to_2d( |
| 94 | + ssht.rotate_flms(flm_1d, alpha, beta, gamma, L), L |
| 95 | + ) |
| 96 | + flm_rot_s2fft = rotate_flms(flm, L, rot, dl) |
| 97 | + |
| 98 | + np.testing.assert_allclose(flm_rot_ssht, flm_rot_s2fft, atol=1e-14) |
0 commit comments