Skip to content

Commit 7e74a0d

Browse files
committed
rename func in custom_vjp
1 parent 1f07ccb commit 7e74a0d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

s2fft/transforms/spherical.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def inverse_jax(
257257

258258
# Perform latitudinal wigner-d recursions
259259
@custom_vjp
260-
def func(flm, s, precomps):
260+
def flm_to_ftm(flm, s, precomps):
261261
return otf.inverse_latitudinal_step_jax(
262262
flm,
263263
thetas,
@@ -272,7 +272,7 @@ def func(flm, s, precomps):
272272
)
273273

274274
def f_fwd(flm, s, precomps):
275-
return func(flm, s, precomps), (jnp.zeros_like(flm), s, [])
275+
return flm_to_ftm(flm, s, precomps), (jnp.zeros_like(flm), s, [])
276276

277277
def f_bwd(res, gtm):
278278
s = res[1]
@@ -289,8 +289,8 @@ def f_bwd(res, gtm):
289289
)
290290
return glm, None, None
291291

292-
func.defvjp(f_fwd, f_bwd)
293-
ftm = func(flm, spin, precomps)
292+
flm_to_ftm.defvjp(f_fwd, f_bwd)
293+
ftm = flm_to_ftm(flm, spin, precomps)
294294

295295
# Correct healpix theta row offsets
296296
if sampling.lower() == "healpix":
@@ -604,7 +604,7 @@ def forward_jax(
604604

605605
# Perform latitudinal wigner-d recursions
606606
@custom_vjp
607-
def func(ftm, s, precomps):
607+
def ftm_to_flm(ftm, s, precomps):
608608
flm = otf.forward_latitudinal_step_jax(
609609
ftm,
610610
thetas,
@@ -620,7 +620,7 @@ def func(ftm, s, precomps):
620620
return flm
621621

622622
def f_fwd(ftm, s, precomps):
623-
return func(ftm, s, precomps), (jnp.zeros_like(ftm), s, [])
623+
return ftm_to_flm(ftm, s, precomps), (jnp.zeros_like(ftm), s, [])
624624

625625
def f_bwd(res, glm):
626626
s = res[1]
@@ -637,8 +637,8 @@ def f_bwd(res, glm):
637637
)
638638
return gtm, None, None
639639

640-
func.defvjp(f_fwd, f_bwd)
641-
flm = func(ftm, spin, precomps)
640+
ftm_to_flm.defvjp(f_fwd, f_bwd)
641+
flm = ftm_to_flm(ftm, spin, precomps)
642642

643643
# Apply harmonic normalisation
644644
flm = flm.at[L_lower:].set(

0 commit comments

Comments
 (0)