Skip to content

Commit 81891d1

Browse files
committed
add gradient finite differences tests for rotation functions
1 parent ec20bff commit 81891d1

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/test_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
from s2fft.sampling import s2_samples as samples
88
from s2fft.utils.augmentation import rotate_flms, generate_rotate_dls
9+
import jax.numpy as jnp
10+
from jax.test_util import check_grads
911

1012
L_to_test = [6, 8, 10]
1113
angles_to_test = [np.pi / 2, np.pi / 6]
@@ -96,3 +98,24 @@ def test_rotate_flms_precompute_dls(
9698
flm_rot_s2fft = rotate_flms(flm, L, rot, dl)
9799

98100
np.testing.assert_allclose(flm_rot_ssht, flm_rot_s2fft, atol=1e-14)
101+
102+
103+
@pytest.mark.parametrize("L", L_to_test)
104+
@pytest.mark.parametrize("alpha", angles_to_test)
105+
@pytest.mark.parametrize("beta", angles_to_test)
106+
@pytest.mark.parametrize("gamma", angles_to_test)
107+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
108+
@pytest.mark.filterwarnings("ignore::FutureWarning")
109+
def test_rotate_flms_gradients(
110+
flm_generator, L: int, alpha: float, beta: float, gamma: float
111+
):
112+
flm_start = flm_generator(L=L)
113+
114+
rot = (alpha, beta, gamma)
115+
flm_target = rotate_flms(flm_start, L, (0.1, 0.1, 0.1))
116+
117+
def func(flm):
118+
flm_rot = rotate_flms(flm, L, rot)
119+
return jnp.sum(jnp.abs(flm_rot - flm_target))
120+
121+
check_grads(func, (flm_start,), order=1, modes=("rev"))

0 commit comments

Comments
 (0)