@@ -257,12 +257,12 @@ def inverse_jax(
257257
258258 # Perform latitudinal wigner-d recursions
259259 @custom_vjp
260- def flm_to_ftm (flm , s , precomps ):
260+ def flm_to_ftm (flm , spin , precomps ):
261261 return otf .inverse_latitudinal_step_jax (
262262 flm ,
263263 thetas ,
264264 L ,
265- s ,
265+ spin ,
266266 nside ,
267267 sampling ,
268268 reality ,
@@ -271,16 +271,16 @@ def flm_to_ftm(flm, s, precomps):
271271 L_lower = L_lower ,
272272 )
273273
274- def f_fwd (flm , s , precomps ):
275- return flm_to_ftm (flm , s , precomps ), ([], s , [])
274+ def f_fwd (flm , spin , precomps ):
275+ return flm_to_ftm (flm , spin , precomps ), ([], spin , [])
276276
277277 def f_bwd (res , gtm ):
278- s = res [1 ]
278+ spin = res [1 ]
279279 glm = otf .forward_latitudinal_step_jax (
280280 gtm ,
281281 thetas ,
282282 L ,
283- s ,
283+ spin ,
284284 nside ,
285285 sampling ,
286286 reality ,
@@ -604,12 +604,12 @@ def forward_jax(
604604
605605 # Perform latitudinal wigner-d recursions
606606 @custom_vjp
607- def ftm_to_flm (ftm , s , precomps ):
607+ def ftm_to_flm (ftm , spin , precomps ):
608608 flm = otf .forward_latitudinal_step_jax (
609609 ftm ,
610610 thetas ,
611611 L ,
612- s ,
612+ spin ,
613613 nside ,
614614 sampling ,
615615 reality ,
@@ -619,16 +619,16 @@ def ftm_to_flm(ftm, s, precomps):
619619 )
620620 return flm
621621
622- def f_fwd (ftm , s , precomps ):
623- return ftm_to_flm (ftm , s , precomps ), ([], s , [])
622+ def f_fwd (ftm , spin , precomps ):
623+ return ftm_to_flm (ftm , spin , precomps ), ([], spin , [])
624624
625625 def f_bwd (res , glm ):
626- s = res [1 ]
626+ spin = res [1 ]
627627 gtm = otf .inverse_latitudinal_step_jax (
628628 glm ,
629629 thetas ,
630630 L ,
631- s ,
631+ spin ,
632632 nside ,
633633 sampling ,
634634 reality ,
0 commit comments