44
55# C backend functions for which to provide JAX frontend.
66import pyssht
7- from jax import custom_vjp
7+ from jax import core , custom_vjp
8+ from jax .interpreters import ad
89
910from s2fft .sampling import reindex
1011from s2fft .utils import quadrature_jax
@@ -241,83 +242,181 @@ def _ssht_forward_bwd(res, flm):
241242 return f , None , None , None , None , None
242243
243244
244- @custom_vjp
245- def healpy_inverse (flm : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
246- r"""
247- Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
245+ # Link JAX gradients for C backend functions
246+ ssht_inverse .defvjp (_ssht_inverse_fwd , _ssht_inverse_bwd )
247+ ssht_forward .defvjp (_ssht_forward_fwd , _ssht_forward_bwd )
248248
249- HEALPix is a C++ library which implements the scalar spherical harmonic transform
250- outlined in [1]. We make use of their healpy python bindings for which we provide
251- custom JAX frontends, hence providing support for automatic differentiation. Currently
252- these transforms can only be deployed on CPU, which is a limitation of the C++ library.
253249
254- Args:
255- flm (jnp.ndarray): Spherical harmonic coefficients.
250+ def _complex_dtype (real_dtype ):
251+ """
252+ Get complex datatype corresponding to a given real datatype.
256253
257- L (int): Harmonic band-limit.
254+ Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L92
258255
259- nside (int, optional): HEALPix Nside resolution parameter. Only required
260- if sampling="healpix". Defaults to None.
256+ Original license:
261257
262- Returns:
263- jnp.ndarray: Signal on the sphere.
258+ Copyright 2019 The JAX Authors.
264259
265- Note:
266- [1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
267- discretization and fast analysis of data distributed on the sphere." The
268- Astrophysical Journal 622.2 (2005): 759
260+ Licensed under the Apache License, Version 2.0 (the "License");
261+ you may not use this file except in compliance with the License.
262+ You may obtain a copy of the License at
263+
264+ https://www.apache.org/licenses/LICENSE-2.0
269265
266+ Unless required by applicable law or agreed to in writing, software
267+ distributed under the License is distributed on an "AS IS" BASIS,
268+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
269+ See the License for the specific language governing permissions and
270+ limitations under the License.
270271 """
271- flm = reindex .flm_2d_to_hp_fast (flm , L )
272- f = jnp .array (healpy .alm2map (np .array (flm ), lmax = L - 1 , nside = nside ))
273- return f
272+ return (np .zeros ((), real_dtype ) + np .zeros ((), np .complex64 )).dtype
273+
274+
275+ def _real_dtype (complex_dtype ):
276+ """
277+ Get real datatype corresponding to a given complex datatype.
278+
279+ Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L93
280+
281+ Original license:
282+
283+ Copyright 2019 The JAX Authors.
284+
285+ Licensed under the Apache License, Version 2.0 (the "License");
286+ you may not use this file except in compliance with the License.
287+ You may obtain a copy of the License at
288+
289+ https://www.apache.org/licenses/LICENSE-2.0
274290
291+ Unless required by applicable law or agreed to in writing, software
292+ distributed under the License is distributed on an "AS IS" BASIS,
293+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
294+ See the License for the specific language governing permissions and
295+ limitations under the License.
296+ """
297+ return np .finfo (complex_dtype ).dtype
298+
299+
300+ def _healpy_map2alm_impl (f : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
301+ return jnp .array (healpy .map2alm (np .array (f ), lmax = L - 1 , iter = 0 ))
275302
276- def _healpy_inverse_fwd (flm : jnp .ndarray , L : int , nside : int ):
277- """Private function which implements the forward pass for inverse jax_healpy."""
278- res = ([], L , nside )
279- return healpy_inverse (flm , L , nside ), res
280303
304+ def _healpy_map2alm_abstract_eval (
305+ f : core .ShapedArray , L : int , nside : int
306+ ) -> core .ShapedArray :
307+ return core .ShapedArray (shape = (L * (L + 1 ) // 2 ,), dtype = _complex_dtype (f .dtype ))
281308
282- def _healpy_inverse_bwd ( res , f ):
283- """Private function which implements the backward pass for inverse jax_healpy."""
284- _ , L , nside = res
285- f_new = f * ( 12 * nside ** 2 ) / ( 4 * jnp . pi )
286- flm_out = jnp . array (
287- np . conj ( healpy . map2alm ( np . conj ( np . array ( f_new )), lmax = L - 1 , iter = 0 ))
309+
310+ def _healpy_map2alm_transpose ( dflm : jnp . ndarray , L : int , nside : int ):
311+ scale_factors = (
312+ jnp . concatenate (( jnp . ones ( L ), 2 * jnp . ones ( L * ( L - 1 ) // 2 )) )
313+ * ( 3 * nside ** 2 )
314+ / jnp . pi
288315 )
289- # iter MUST be zero otherwise gradient propagation is incorrect (JDM).
290- flm_out = reindex .flm_hp_to_2d_fast (flm_out , L )
291- m_conj = (- 1 ) ** (jnp .arange (1 , L ) % 2 )
292- flm_out = flm_out .at [..., L :].add (
293- jnp .flip (m_conj * jnp .conj (flm_out [..., : L - 1 ]), axis = - 1 )
316+ return (jnp .conj (healpy_alm2map (jnp .conj (dflm ) / scale_factors , L , nside )),)
317+
318+
319+ _healpy_map2alm_p = core .Primitive ("healpy_map2alm" )
320+ _healpy_map2alm_p .def_impl (_healpy_map2alm_impl )
321+ _healpy_map2alm_p .def_abstract_eval (_healpy_map2alm_abstract_eval )
322+ ad .deflinear (_healpy_map2alm_p , _healpy_map2alm_transpose )
323+
324+
325+ def healpy_map2alm (f : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
326+ """
327+ JAX wrapper for healpy map2alm function (forward spherical harmonic transform).
328+
329+ This wrapper will return the spherical harmonic coefficients as a one dimensional
330+ array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional
331+ array of harmonic coefficients use :py:func:`healpy_forward`.
332+
333+ Args:
334+ f (jnp.ndarray): Signal on the sphere.
335+
336+ L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
337+
338+ nside (int): HEALPix Nside resolution parameter.
339+
340+ Returns:
341+ jnp.ndarray: Harmonic coefficients of signal f.
342+
343+ """
344+ return _healpy_map2alm_p .bind (f , L = L , nside = nside )
345+
346+
347+ def _healpy_alm2map_impl (flm : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
348+ return jnp .array (healpy .alm2map (np .array (flm ), lmax = L - 1 , nside = nside ))
349+
350+
351+ def _healpy_alm2map_abstract_eval (
352+ flm : core .ShapedArray , L : int , nside : int
353+ ) -> core .ShapedArray :
354+ return core .ShapedArray (shape = (12 * nside ** 2 ,), dtype = _real_dtype (flm .dtype ))
355+
356+
357+ def _healpy_alm2map_transpose (df : jnp .ndarray , L : int , nside : int ) -> tuple :
358+ scale_factors = (
359+ jnp .concatenate ((jnp .ones (L ), 2 * jnp .ones (L * (L - 1 ) // 2 )))
360+ * (3 * nside ** 2 )
361+ / jnp .pi
294362 )
295- flm_out = flm_out .at [..., : L - 1 ].set (0 )
363+ # Scale factor above includes the inverse quadrature weight given by
364+ # (12 * nside**2) / (4 * jnp.pi) = (3 * nside**2) / jnp.pi
365+ # and also a factor of 2 for m>0 to account for the negative m.
366+ # See explanation in this issue comment:
367+ # https://github.com/astro-informatics/s2fft/issues/243#issuecomment-2500951488
368+ return (scale_factors * jnp .conj (healpy_map2alm (jnp .conj (df ), L , nside )),)
296369
297- return flm_out , None , None
370+
371+ _healpy_alm2map_p = core .Primitive ("healpy_alm2map" )
372+ _healpy_alm2map_p .def_impl (_healpy_alm2map_impl )
373+ _healpy_alm2map_p .def_abstract_eval (_healpy_alm2map_abstract_eval )
374+ ad .deflinear (_healpy_alm2map_p , _healpy_alm2map_transpose )
375+
376+
377+ def healpy_alm2map (flm : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
378+ """
379+ JAX wrapper for healpy alm2map function (inverse spherical harmonic transform).
380+
381+ This wrapper assumes the passed spherical harmonic coefficients are a one
382+ dimensional array using HEALPix (ring-ordered) indexing. To instead pass a
383+ two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`.
384+
385+ Args:
386+ flm (jnp.ndarray): Spherical harmonic coefficients.
387+
388+ L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
389+
390+ nside (int): HEALPix Nside resolution parameter.
391+
392+ Returns:
393+ jnp.ndarray: Signal on the sphere.
394+
395+ """
396+ return _healpy_alm2map_p .bind (flm , L = L , nside = nside )
298397
299398
300- @custom_vjp
301399def healpy_forward (f : jnp .ndarray , L : int , nside : int , iter : int = 3 ) -> jnp .ndarray :
302400 r"""
303401 Compute the forward scalar spherical harmonic transform (healpy JAX).
304402
305403 HEALPix is a C++ library which implements the scalar spherical harmonic transform
306404 outlined in [1]. We make use of their healpy python bindings for which we provide
307- custom JAX frontends, hence providing support for automatic differentiation. Currently
308- these transforms can only be deployed on CPU, which is a limitation of the C++ library.
405+ custom JAX frontends, hence providing support for automatic differentiation.
406+ Currently these transforms can only be deployed on CPU, which is a limitation of the
407+ C++ library.
309408
310409 Args:
311410 f (jnp.ndarray): Signal on the sphere.
312411
313412 L (int): Harmonic band-limit.
314413
315- nside (int, optional): HEALPix Nside resolution parameter. Only required
316- if sampling="healpix". Defaults to None.
414+ nside (int): HEALPix Nside resolution parameter.
317415
318- iter (int, optional): Number of subiterations for healpy. Note that iterations
319- increase the precision of the forward transform, but reduce the accuracy of
320- the gradient pass. Between 2 and 3 iterations is a good compromise.
416+ iter (int, optional): Number of subiterations (iterative refinement steps) for
417+ healpy. Note that iterations increase the precision of the forward transform
418+ as an inverse of inverse transform, but with a linear increase in
419+ computational cost. Between 2 and 3 iterations is a good compromise.
321420
322421 Returns:
323422 jnp.ndarray: Harmonic coefficients of signal f.
@@ -328,28 +427,39 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
328427 Astrophysical Journal 622.2 (2005): 759
329428
330429 """
331- flm = jnp .array (healpy .map2alm (np .array (f ), lmax = L - 1 , iter = iter ))
430+ flm = healpy_map2alm (f , L , nside )
431+ for _ in range (iter ):
432+ f_recov = healpy_alm2map (flm , L , nside )
433+ f_error = f - f_recov
434+ flm += healpy_map2alm (f_error , L , nside )
332435 return reindex .flm_hp_to_2d_fast (flm , L )
333436
334437
335- def _healpy_forward_fwd (f : jnp .ndarray , L : int , nside : int , iter : int = 3 ):
336- """Private function which implements the forward pass for forward jax_healpy."""
337- res = ([], L , nside , iter )
338- return healpy_forward (f , L , nside , iter ), res
438+ def healpy_inverse (flm : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
439+ r"""
440+ Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
339441
442+ HEALPix is a C++ library which implements the scalar spherical harmonic transform
443+ outlined in [1]. We make use of their healpy python bindings for which we provide
444+ custom JAX frontends, hence providing support for automatic differentiation.
445+ Currently these transforms can only be deployed on CPU, which is a limitation of the
446+ C++ library.
340447
341- def _healpy_forward_bwd (res , flm ):
342- """Private function which implements the backward pass for forward jax_healpy."""
343- _ , L , nside , _ = res
344- flm_new = reindex .flm_2d_to_hp_fast (flm , L )
345- f = jnp .array (
346- np .conj (healpy .alm2map (np .conj (np .array (flm_new )), lmax = L - 1 , nside = nside ))
347- )
348- return f * (4 * jnp .pi ) / (12 * nside ** 2 ), None , None , None
448+ Args:
449+ flm (jnp.ndarray): Spherical harmonic coefficients.
349450
451+ L (int): Harmonic band-limit.
350452
351- # Link JAX gradients for C backend functions
352- ssht_inverse .defvjp (_ssht_inverse_fwd , _ssht_inverse_bwd )
353- ssht_forward .defvjp (_ssht_forward_fwd , _ssht_forward_bwd )
354- healpy_inverse .defvjp (_healpy_inverse_fwd , _healpy_inverse_bwd )
355- healpy_forward .defvjp (_healpy_forward_fwd , _healpy_forward_bwd )
453+ nside (int): HEALPix Nside resolution parameter.
454+
455+ Returns:
456+ jnp.ndarray: Signal on the sphere.
457+
458+ Note:
459+ [1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
460+ discretization and fast analysis of data distributed on the sphere." The
461+ Astrophysical Journal 622.2 (2005): 759
462+
463+ """
464+ flm = reindex .flm_2d_to_hp_fast (flm , L )
465+ return healpy_alm2map (flm , L , nside )
0 commit comments