Skip to content

Commit 4139c39

Browse files
committed
Test across different iter values
1 parent 2f125b4 commit 4139c39

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

tests/test_spherical_custom_grads.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,15 @@ def func(flm):
161161
@pytest.mark.parametrize("L_lower", L_lower_to_test)
162162
@pytest.mark.parametrize("spin", spin_to_test)
163163
@pytest.mark.parametrize("reality", reality_to_test)
164+
@pytest.mark.parametrize("iter", [0, 1, 2, 3])
164165
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
165166
def test_healpix_forward_custom_gradients(
166167
flm_generator,
167168
nside: int,
168169
L_lower: int,
169170
spin: int,
170171
reality: bool,
172+
iter: int,
171173
):
172174
sampling = "healpix"
173175
L = 2 * nside
@@ -191,15 +193,17 @@ def test_healpix_forward_custom_gradients(
191193
)
192194

193195
def func(f):
194-
flm = spherical.forward_jax(
196+
flm = spherical.forward(
195197
f,
196198
L,
199+
method="jax",
197200
spin=spin,
198201
nside=nside,
199202
L_lower=L_lower,
200203
reality=reality,
201204
precomps=precomps,
202205
sampling=sampling,
206+
iter=iter,
203207
)
204208
return jnp.sum(jnp.abs(flm - flm_target) ** 2)
205209

tests/test_spherical_transform.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,14 @@ def test_transform_forward(
150150
@pytest.mark.parametrize("nside", nside_to_test)
151151
@pytest.mark.parametrize("method", method_to_test)
152152
@pytest.mark.parametrize("spmd", multiple_gpus)
153+
@pytest.mark.parametrize("iter", [0, 1, 2, 3])
153154
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
154155
def test_transform_forward_healpix(
155156
flm_generator,
156157
nside: int,
157158
method: str,
158159
spmd: bool,
160+
iter: int,
159161
):
160162
sampling = "healpix"
161163
L = 2 * nside
@@ -174,10 +176,11 @@ def test_transform_forward_healpix(
174176
reality=True,
175177
precomps=precomps,
176178
spmd=spmd,
179+
iter=iter,
177180
)
178181
flm_check = samples.flm_2d_to_hp(flm_check, L)
179182

180-
flm = hp.sphtfunc.map2alm(f, lmax=L - 1, iter=0)
183+
flm = hp.sphtfunc.map2alm(f, lmax=L - 1, iter=iter)
181184

182185
np.testing.assert_allclose(flm, flm_check, atol=1e-14)
183186

0 commit comments

Comments
 (0)