11import numpy as np
22import jax .numpy as jnp
33import jax .lax as lax
4- from jax import jit , pmap , local_device_count
4+ from jax import jit , pmap , local_device_count , custom_vjp
55from functools import partial
66from typing import List
77from s2fft .sampling import s2_samples as samples
@@ -225,7 +225,8 @@ def inverse_latitudinal_step_jax(
225225
226226 mm = - spin # switch to match convention
227227 ntheta = len (beta ) # Number of theta samples
228- ftm = jnp .zeros (samples .ftm_shape (L , sampling , nside ), dtype = jnp .complex128 )
228+ m_count = 2 * L if sampling .lower () in ["mwss" , "healpix" ] else 2 * L - 1
229+ ftm = jnp .zeros ((ntheta , m_count ), dtype = jnp .complex128 )
229230 el = jnp .arange (L_lower , L )
230231
231232 # Trigonometric constant adopted throughout
@@ -239,7 +240,7 @@ def inverse_latitudinal_step_jax(
239240
240241 if precomps is None :
241242 precomps = generate_precomputes_jax (
242- L , - mm , sampling , nside , L_lower = L_lower
243+ L , - mm , sampling , nside , L_lower = L_lower , betas = beta
243244 )
244245 lrenorm , vsign , cpi , cp2 , indices = precomps
245246
@@ -401,6 +402,25 @@ def eval_recursion_step(
401402 pm_recursion_step ,
402403 (ftm , dl_entry , dl_iter , lrenorm , indices , omc , c , s ),
403404 )
405+
406+ # Remove south pole singularity
407+ m_offset = 1 if sampling .lower () in ["mwss" , "healpix" ] else 0
408+ if sampling .lower () in ["mw" , "mwss" ]:
409+ ftm = ftm .at [- 1 ].set (0 )
410+ ftm = ftm .at [- 1 , L - 1 + spin + m_offset ].set (
411+ jnp .nansum (
412+ (- 1 ) ** abs (jnp .arange (L_lower , L ) - spin )
413+ * flm [L_lower :, L - 1 + spin ]
414+ )
415+ )
416+
417+ # Remove north pole singularity
418+ if sampling .lower () == "mwss" :
419+ ftm = ftm .at [0 ].set (0 )
420+ ftm = ftm .at [0 , L - 1 - spin + m_offset ].set (
421+ jnp .nansum (flm [L_lower :, L - 1 - spin ])
422+ )
423+
404424 return ftm
405425
406426
@@ -562,8 +582,8 @@ def forward_latitudinal_step(
562582
563583@partial (jit , static_argnums = (2 , 4 , 5 , 6 , 8 , 9 ))
564584def forward_latitudinal_step_jax (
565- ftm : jnp .ndarray ,
566- beta : jnp .ndarray ,
585+ ftm_in : jnp .ndarray ,
586+ beta_in : jnp .ndarray ,
567587 L : int ,
568588 spin : int ,
569589 nside : int ,
@@ -622,6 +642,16 @@ def forward_latitudinal_step_jax(
622642 between devices is noticable, however as L increases one will asymptotically
623643 recover acceleration by the number of devices.
624644 """
645+ # Avoid pole-singularities for MWSS sampling
646+ if sampling .lower () == "mwss" :
647+ ftm = ftm_in [1 :- 1 ]
648+ beta = beta_in [1 :- 1 ]
649+ elif sampling .lower () == "mw" :
650+ ftm = ftm_in [:- 1 ]
651+ beta = beta_in [:- 1 ]
652+ else :
653+ ftm = ftm_in
654+ beta = beta_in
625655
626656 mm = - spin # switch to match convention
627657 ntheta = len (beta ) # Number of theta samples
@@ -639,7 +669,7 @@ def forward_latitudinal_step_jax(
639669
640670 if precomps is None :
641671 precomps = generate_precomputes_jax (
642- L , - mm , sampling , nside , True , L_lower
672+ L , - mm , sampling , nside , True , L_lower , betas = beta
643673 )
644674 lrenorm , vsign , cpi , cp2 , indices = precomps
645675
@@ -840,4 +870,17 @@ def eval_recursion_step(
840870 ),
841871 )[0 ]
842872 )
873+
874+ # Include both pole singularities explicitly
875+ m_offset = 1 if sampling .lower () in ["mwss" , "healpix" ] else 0
876+ if sampling .lower () in ["mw" , "mwss" ]:
877+ flm = flm .at [L_lower :, L - 1 + spin ].add (
878+ (- 1 ) ** abs (jnp .arange (L_lower , L ) - spin )
879+ * ftm_in [- 1 , L - 1 + spin + m_offset ]
880+ )
881+
882+ if sampling .lower () == "mwss" :
883+ flm = flm .at [L_lower :, L - 1 - spin ].add (
884+ ftm_in [0 , L - 1 - spin + m_offset ]
885+ )
843886 return flm
0 commit comments