@@ -307,22 +307,16 @@ def func(f):
307307@pytest .mark .parametrize ("nside" , nside_to_test )
308308@pytest .mark .filterwarnings ("ignore::RuntimeWarning" )
309309def test_healpix_c_backend_inverse_custom_gradients (flm_generator , nside : int ):
310- sampling = "healpix"
311310 L = 2 * nside
312311 reality = True
313312 flm = flm_generator (L = L , reality = reality )
314- flm_target = flm_generator (L = L , reality = reality )
315- f_target = spherical .inverse_jax (
316- flm_target , L , nside = nside , sampling = sampling , reality = reality
317- )
318313
319314 def func (flm ):
320- f = spherical .inverse (
315+ return spherical .inverse (
321316 flm , L , 0 , nside , sampling = "healpix" , method = "jax_healpy" , reality = True
322317 )
323- return jnp .sum (jnp .abs (f - f_target ) ** 2 )
324318
325- check_grads (func , (flm ,), order = 1 , modes = ("rev" ))
319+ check_grads (func , (flm ,), order = 2 , modes = ("fwd" , "rev" ))
326320
327321
328322@pytest .mark .parametrize ("nside" , nside_to_test )
@@ -334,16 +328,12 @@ def test_healpix_c_backend_forward_custom_gradients(
334328 sampling = "healpix"
335329 L = 2 * nside
336330 reality = True
337- flm_target = flm_generator (L = L , reality = reality )
338331 flm = flm_generator (L = L , reality = reality )
339332 f = spherical .inverse_jax (flm , L , nside = nside , sampling = sampling , reality = reality )
340333
341334 def func (f ):
342- flm = spherical .forward (
335+ return spherical .forward (
343336 f , L , nside = nside , sampling = "healpix" , method = "jax_healpy" , iter = iter
344337 )
345- return jnp .sum (jnp .abs (flm - flm_target ) ** 2 )
346-
347- rtol = [1e-6 , 1e-2 , 5e-2 , 1e-2 ][iter ]
348338
349- check_grads (func , (f ,), order = 1 , modes = ("rev" ), rtol = rtol )
339+ check_grads (func , (f ,), order = 2 , modes = ("fwd" , "rev" ) )
0 commit comments