|
6 | 6 | import numpy as np |
7 | 7 | from s2fft.sampling import s2_samples as samples |
8 | 8 | from s2fft.utils.augmentation import rotate_flms, generate_rotate_dls |
| 9 | +import jax.numpy as jnp |
| 10 | +from jax.test_util import check_grads |
9 | 11 |
|
10 | 12 | L_to_test = [6, 8, 10] |
11 | 13 | angles_to_test = [np.pi / 2, np.pi / 6] |
@@ -96,3 +98,24 @@ def test_rotate_flms_precompute_dls( |
96 | 98 | flm_rot_s2fft = rotate_flms(flm, L, rot, dl) |
97 | 99 |
|
98 | 100 | np.testing.assert_allclose(flm_rot_ssht, flm_rot_s2fft, atol=1e-14) |
| 101 | + |
| 102 | + |
| 103 | +@pytest.mark.parametrize("L", L_to_test) |
| 104 | +@pytest.mark.parametrize("alpha", angles_to_test) |
| 105 | +@pytest.mark.parametrize("beta", angles_to_test) |
| 106 | +@pytest.mark.parametrize("gamma", angles_to_test) |
| 107 | +@pytest.mark.filterwarnings("ignore::RuntimeWarning") |
| 108 | +@pytest.mark.filterwarnings("ignore::FutureWarning") |
| 109 | +def test_rotate_flms_gradients( |
| 110 | + flm_generator, L: int, alpha: float, beta: float, gamma: float |
| 111 | +): |
| 112 | + flm_start = flm_generator(L=L) |
| 113 | + |
| 114 | + rot = (alpha, beta, gamma) |
| 115 | + flm_target = rotate_flms(flm_start, L, (0.1, 0.1, 0.1)) |
| 116 | + |
| 117 | + def func(flm): |
| 118 | + flm_rot = rotate_flms(flm, L, rot) |
| 119 | + return jnp.sum(jnp.abs(flm_rot - flm_target)) |
| 120 | + |
| 121 | + check_grads(func, (flm_start,), order=1, modes=("rev")) |
0 commit comments