@@ -74,9 +74,7 @@ def spin_spherical_kernel(
7474 if recursion .lower () == "auto" :
7575 # This mode automatically determines which recursion is best suited for the
7676 # current parameter configuration.
77- recursion = (
78- "risbo" if abs (spin ) >= PM_MAX_STABLE_SPIN else "price-mcewen"
79- )
77+ recursion = "risbo" if abs (spin ) >= PM_MAX_STABLE_SPIN else "price-mcewen"
8078
8179 dl = []
8280 m_start_ind = L - 1 if reality else 0
@@ -111,13 +109,9 @@ def spin_spherical_kernel(
111109 # - The complexity of this approach is O(L^4).
112110 # - This approach is stable for arbitrary abs(spins) <= L.
113111 if sampling .lower () in ["healpix" , "gl" ]:
114- delta = np .zeros (
115- (len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = np .float64
116- )
112+ delta = np .zeros ((len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
117113 for el in range (L ):
118- delta = recursions .risbo .compute_full_vectorised (
119- delta , thetas , L , el
120- )
114+ delta = recursions .risbo .compute_full_vectorised (delta , thetas , L , el )
121115 dl [:, el ] = delta [:, m_start_ind :, L - 1 - spin ]
122116
123117 # MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
@@ -144,19 +138,13 @@ def spin_spherical_kernel(
144138 delta [:, L - 1 - spin ],
145139 1j ** (- spin - m_value [m_start_ind :]),
146140 )
147- temp = np .einsum (
148- "am,a->am" , temp , np .exp (1j * m_value * thetas [0 ])
149- )
150- temp = np .fft .irfft (
151- temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward"
152- )
141+ temp = np .einsum ("am,a->am" , temp , np .exp (1j * m_value * thetas [0 ]))
142+ temp = np .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
153143
154144 dl [:, el ] = temp [: len (thetas )]
155145
156146 # Fold in normalisation to avoid recomputation at run-time.
157- dl = np .einsum (
158- "tlm,l->tlm" , dl , np .sqrt ((2 * np .arange (L ) + 1 ) / (4 * np .pi ))
159- )
147+ dl = np .einsum ("tlm,l->tlm" , dl , np .sqrt ((2 * np .arange (L ) + 1 ) / (4 * np .pi )))
160148
161149 else :
162150 raise ValueError (f"Recursion method { recursion } not recognised." )
@@ -234,9 +222,7 @@ def spin_spherical_kernel_jax(
234222 if recursion .lower () == "auto" :
235223 # This mode automatically determines which recursion is best suited for the
236224 # current parameter configuration.
237- recursion = (
238- "risbo" if abs (spin ) >= PM_MAX_STABLE_SPIN else "price-mcewen"
239- )
225+ recursion = "risbo" if abs (spin ) >= PM_MAX_STABLE_SPIN else "price-mcewen"
240226
241227 dl = []
242228 m_start_ind = L - 1 if reality else 0
@@ -283,9 +269,7 @@ def spin_spherical_kernel_jax(
283269 # - The complexity of this approach is O(L^4).
284270 # - This approach is stable for arbitrary abs(spins) <= L.
285271 if sampling .lower () in ["healpix" , "gl" ]:
286- delta = jnp .zeros (
287- (len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64
288- )
272+ delta = jnp .zeros ((len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
289273 vfunc = jax .vmap (
290274 recursions .risbo_jax .compute_full , in_axes = (0 , 0 , None , None )
291275 )
@@ -309,32 +293,24 @@ def spin_spherical_kernel_jax(
309293
310294 # Calculate the Fourier coefficients of the Wigner d-functions, delta(pi/2).
311295 for el in range (L ):
312- delta = recursions .risbo_jax .compute_full (
313- delta , jnp .pi / 2 , L , el
314- )
296+ delta = recursions .risbo_jax .compute_full (delta , jnp .pi / 2 , L , el )
315297 m_value = jnp .arange (- L + 1 , L )
316298 temp = jnp .einsum (
317299 "am,a,m->am" ,
318300 delta [:, m_start_ind :],
319301 delta [:, L - 1 - spin ],
320302 1j ** (- spin - m_value [m_start_ind :]),
321303 )
322- temp = jnp .einsum (
323- "am,a->am" , temp , jnp .exp (1j * m_value * thetas [0 ])
324- )
325- temp = jnp .fft .irfft (
326- temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward"
327- )
304+ temp = jnp .einsum ("am,a->am" , temp , jnp .exp (1j * m_value * thetas [0 ]))
305+ temp = jnp .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
328306
329307 dl = dl .at [:, el ].set (temp [: len (thetas )])
330308
331309 else :
332310 raise ValueError (f"Recursion method { recursion } not recognised." )
333311
334312 # Fold in normalisation to avoid recomputation at run-time.
335- dl = jnp .einsum (
336- "tlm,l->tlm" , dl , jnp .sqrt ((2 * jnp .arange (L ) + 1 ) / (4 * jnp .pi ))
337- )
313+ dl = jnp .einsum ("tlm,l->tlm" , dl , jnp .sqrt ((2 * jnp .arange (L ) + 1 ) / (4 * jnp .pi )))
338314
339315 # Fold in quadrature to avoid recomputation at run-time.
340316 if forward :
@@ -433,9 +409,7 @@ def wigner_kernel(
433409 if mode .lower () == "direct" :
434410 delta = np .zeros ((len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
435411 for el in range (L ):
436- delta = recursions .risbo .compute_full_vectorised (
437- delta , thetas , L , el
438- )
412+ delta = recursions .risbo .compute_full_vectorised (delta , thetas , L , el )
439413 dl [:, :, el ] = np .moveaxis (delta , - 1 , 0 )[L - 1 + n ]
440414
441415 # MW, MWSS, and DH sampling ARE uniform in theta therefore CAN be calculated
@@ -464,9 +438,7 @@ def wigner_kernel(
464438 1j ** (- m_value ),
465439 1j ** (n ),
466440 )
467- temp = np .einsum (
468- "amn,a->amn" , temp , np .exp (1j * m_value * thetas [0 ])
469- )
441+ temp = np .einsum ("amn,a->amn" , temp , np .exp (1j * m_value * thetas [0 ]))
470442 temp = np .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
471443 dl [:, :, el ] = np .moveaxis (temp [: len (thetas )], - 1 , 0 )
472444
@@ -574,12 +546,8 @@ def wigner_kernel_jax(
574546 # - The complexity of this approach is ALWAYS O(L^4).
575547 # - This approach is stable for arbitrary abs(spins) <= L.
576548 if mode .lower () == "direct" :
577- delta = jnp .zeros (
578- (len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64
579- )
580- vfunc = jax .vmap (
581- recursions .risbo_jax .compute_full , in_axes = (0 , 0 , None , None )
582- )
549+ delta = jnp .zeros ((len (thetas ), 2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
550+ vfunc = jax .vmap (recursions .risbo_jax .compute_full , in_axes = (0 , 0 , None , None ))
583551 for el in range (L ):
584552 delta = vfunc (delta , thetas , L , el )
585553 dl = dl .at [:, :, el ].set (jnp .moveaxis (delta , - 1 , 0 )[L - 1 + n ])
@@ -610,12 +578,8 @@ def wigner_kernel_jax(
610578 1j ** (- m_value ),
611579 1j ** (n ),
612580 )
613- temp = jnp .einsum (
614- "amn,a->amn" , temp , jnp .exp (1j * m_value * thetas [0 ])
615- )
616- temp = jnp .fft .irfft (
617- temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward"
618- )
581+ temp = jnp .einsum ("amn,a->amn" , temp , jnp .exp (1j * m_value * thetas [0 ]))
582+ temp = jnp .fft .irfft (temp [L - 1 :], n = nsamps , axis = 0 , norm = "forward" )
619583 dl = dl .at [:, :, el ].set (jnp .moveaxis (temp [: len (thetas )], - 1 , 0 ))
620584
621585 else :
@@ -646,9 +610,7 @@ def wigner_kernel_jax(
646610 return dl
647611
648612
649- def healpix_phase_shifts (
650- L : int , nside : int , forward : bool = False
651- ) -> np .ndarray :
613+ def healpix_phase_shifts (L : int , nside : int , forward : bool = False ) -> np .ndarray :
652614 r"""
653615 Generates a phase shift vector for HEALPix for all :math:`\theta` rings.
654616
0 commit comments