@@ -257,7 +257,7 @@ def inverse_jax(
257257
258258 # Perform latitudinal wigner-d recursions
259259 @custom_vjp
260- def func (flm , s , precomps ):
260+ def flm_to_ftm (flm , s , precomps ):
261261 return otf .inverse_latitudinal_step_jax (
262262 flm ,
263263 thetas ,
@@ -272,7 +272,7 @@ def func(flm, s, precomps):
272272 )
273273
274274 def f_fwd (flm , s , precomps ):
275- return func (flm , s , precomps ), (jnp .zeros_like (flm ), s , [])
275+ return flm_to_ftm (flm , s , precomps ), (jnp .zeros_like (flm ), s , [])
276276
277277 def f_bwd (res , gtm ):
278278 s = res [1 ]
@@ -289,8 +289,8 @@ def f_bwd(res, gtm):
289289 )
290290 return glm , None , None
291291
292- func .defvjp (f_fwd , f_bwd )
293- ftm = func (flm , spin , precomps )
292+ flm_to_ftm .defvjp (f_fwd , f_bwd )
293+ ftm = flm_to_ftm (flm , spin , precomps )
294294
295295 # Correct healpix theta row offsets
296296 if sampling .lower () == "healpix" :
@@ -604,7 +604,7 @@ def forward_jax(
604604
605605 # Perform latitudinal wigner-d recursions
606606 @custom_vjp
607- def func (ftm , s , precomps ):
607+ def ftm_to_flm (ftm , s , precomps ):
608608 flm = otf .forward_latitudinal_step_jax (
609609 ftm ,
610610 thetas ,
@@ -620,7 +620,7 @@ def func(ftm, s, precomps):
620620 return flm
621621
622622 def f_fwd (ftm , s , precomps ):
623- return func (ftm , s , precomps ), (jnp .zeros_like (ftm ), s , [])
623+ return ftm_to_flm (ftm , s , precomps ), (jnp .zeros_like (ftm ), s , [])
624624
625625 def f_bwd (res , glm ):
626626 s = res [1 ]
@@ -637,8 +637,8 @@ def f_bwd(res, glm):
637637 )
638638 return gtm , None , None
639639
640- func .defvjp (f_fwd , f_bwd )
641- flm = func (ftm , spin , precomps )
640+ ftm_to_flm .defvjp (f_fwd , f_bwd )
641+ flm = ftm_to_flm (ftm , spin , precomps )
642642
643643 # Apply harmonic normalisation
644644 flm = flm .at [L_lower :].set (
0 commit comments