44
55# C backend functions for which to provide JAX frontend.
66import pyssht
7- from jax import custom_vjp
7+ from jax import core , custom_vjp
8+ from jax .interpreters import ad
89
910from s2fft .sampling import reindex
1011from s2fft .utils import quadrature_jax
@@ -241,83 +242,176 @@ def _ssht_forward_bwd(res, flm):
241242 return f , None , None , None , None , None
242243
243244
244- @custom_vjp
245- def healpy_inverse (flm : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
246- r"""
247- Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
245+ # Link JAX gradients for C backend functions
246+ ssht_inverse .defvjp (_ssht_inverse_fwd , _ssht_inverse_bwd )
247+ ssht_forward .defvjp (_ssht_forward_fwd , _ssht_forward_bwd )
248248
249- HEALPix is a C++ library which implements the scalar spherical harmonic transform
250- outlined in [1]. We make use of their healpy python bindings for which we provide
251- custom JAX frontends, hence providing support for automatic differentiation. Currently
252- these transforms can only be deployed on CPU, which is a limitation of the C++ library.
253249
254- Args:
255- flm (jnp.ndarray): Spherical harmonic coefficients.
250+ def _complex_dtype (real_dtype ):
251+ """
252+ Get complex datatype corresponding to a given real datatype.
256253
257- L (int): Harmonic band-limit.
254+ Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L92
258255
259- nside (int, optional): HEALPix Nside resolution parameter. Only required
260- if sampling="healpix". Defaults to None.
256+ Original license:
261257
262- Returns:
263- jnp.ndarray: Signal on the sphere.
258+ Copyright 2019 The JAX Authors.
264259
265- Note:
266- [1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
267- discretization and fast analysis of data distributed on the sphere." The
268- Astrophysical Journal 622.2 (2005): 759
260+ Licensed under the Apache License, Version 2.0 (the "License");
261+ you may not use this file except in compliance with the License.
262+ You may obtain a copy of the License at
263+
264+ https://www.apache.org/licenses/LICENSE-2.0
269265
266+ Unless required by applicable law or agreed to in writing, software
267+ distributed under the License is distributed on an "AS IS" BASIS,
268+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
269+ See the License for the specific language governing permissions and
270+ limitations under the License.
270271 """
271- flm = reindex .flm_2d_to_hp_fast (flm , L )
272- f = jnp .array (healpy .alm2map (np .array (flm ), lmax = L - 1 , nside = nside ))
273- return f
272+ return (np .zeros ((), real_dtype ) + np .zeros ((), np .complex64 )).dtype
273+
274+
275+ def _real_dtype (complex_dtype ):
276+ """
277+ Get real datatype corresponding to a given complex datatype.
278+
279+ Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L93
280+
281+ Original license:
282+
283+ Copyright 2019 The JAX Authors.
284+
285+ Licensed under the Apache License, Version 2.0 (the "License");
286+ you may not use this file except in compliance with the License.
287+ You may obtain a copy of the License at
288+
289+ https://www.apache.org/licenses/LICENSE-2.0
274290
291+ Unless required by applicable law or agreed to in writing, software
292+ distributed under the License is distributed on an "AS IS" BASIS,
293+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
294+ See the License for the specific language governing permissions and
295+ limitations under the License.
296+ """
297+ return np .finfo (complex_dtype ).dtype
298+
299+
300+ def _healpy_map2alm_impl (f : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
301+ return jnp .array (healpy .map2alm (np .array (f ), lmax = L - 1 , iter = 0 ))
275302
276- def _healpy_inverse_fwd (flm : jnp .ndarray , L : int , nside : int ):
277- """Private function which implements the forward pass for inverse jax_healpy."""
278- res = ([], L , nside )
279- return healpy_inverse (flm , L , nside ), res
280303
304+ def _healpy_map2alm_abstract_eval (
305+ f : core .ShapedArray , L : int , nside : int
306+ ) -> core .ShapedArray :
307+ return core .ShapedArray (shape = (L * (L + 1 ) // 2 ,), dtype = _complex_dtype (f .dtype ))
281308
282- def _healpy_inverse_bwd ( res , f ):
283- """Private function which implements the backward pass for inverse jax_healpy."""
284- _ , L , nside = res
285- f_new = f * ( 12 * nside ** 2 ) / ( 4 * jnp . pi )
286- flm_out = jnp . array (
287- np . conj ( healpy . map2alm ( np . conj ( np . array ( f_new )), lmax = L - 1 , iter = 0 ))
309+
310+ def _healpy_map2alm_transpose ( dflm : jnp . ndarray , L : int , nside : int ):
311+ scale_factors = (
312+ jnp . concatenate (( jnp . ones ( L ), 2 * jnp . ones ( L * ( L - 1 ) // 2 )) )
313+ * ( 3 * nside ** 2 )
314+ / jnp . pi
288315 )
289- # iter MUST be zero otherwise gradient propagation is incorrect (JDM).
290- flm_out = reindex .flm_hp_to_2d_fast (flm_out , L )
291- m_conj = (- 1 ) ** (jnp .arange (1 , L ) % 2 )
292- flm_out = flm_out .at [..., L :].add (
293- jnp .flip (m_conj * jnp .conj (flm_out [..., : L - 1 ]), axis = - 1 )
316+ return (jnp .conj (healpy_alm2map (jnp .conj (dflm ) / scale_factors , L , nside )),)
317+
318+
319+ _healpy_map2alm_p = core .Primitive ("healpy_map2alm" )
320+ _healpy_map2alm_p .def_impl (_healpy_map2alm_impl )
321+ _healpy_map2alm_p .def_abstract_eval (_healpy_map2alm_abstract_eval )
322+ ad .deflinear (_healpy_map2alm_p , _healpy_map2alm_transpose )
323+
324+
325+ def healpy_map2alm (f : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
326+ """
327+ JAX wrapper for healpy.map2alm function (forward spherical harmonic transform).
328+
329+ This wrapper will return the spherical harmonic coefficients as a one dimensional
330+ array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional
331+ array of harmonic coefficients use :py:func:`healpy_forward`.
332+
333+ Args:
334+ f (jnp.ndarray): Signal on the sphere.
335+
336+ L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
337+
338+ nside (int): HEALPix Nside resolution parameter.
339+
340+ Returns:
341+ jnp.ndarray: Harmonic coefficients of signal f.
342+
343+ """
344+ return _healpy_map2alm_p .bind (f , L = L , nside = nside )
345+
346+
347+ def _healpy_alm2map_impl (flm : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
348+ return jnp .array (healpy .alm2map (np .array (flm ), lmax = L - 1 , nside = nside ))
349+
350+
351+ def _healpy_alm2map_abstract_eval (
352+ flm : core .ShapedArray , L : int , nside : int
353+ ) -> core .ShapedArray :
354+ return core .ShapedArray (shape = (12 * nside ** 2 ,), dtype = _real_dtype (flm .dtype ))
355+
356+
357+ def _healpy_alm2map_transpose (df : jnp .ndarray , L : int , nside : int ) -> tuple :
358+ scale_factors = (
359+ jnp .concatenate ((jnp .ones (L ), 2 * jnp .ones (L * (L - 1 ) // 2 )))
360+ * (3 * nside ** 2 )
361+ / jnp .pi
294362 )
295- flm_out = flm_out . at [..., : L - 1 ]. set ( 0 )
363+ return ( scale_factors * jnp . conj ( healpy_map2alm ( jnp . conj ( df ), L , nside )), )
296364
297- return flm_out , None , None
365+
366+ _healpy_alm2map_p = core .Primitive ("healpy_alm2map" )
367+ _healpy_alm2map_p .def_impl (_healpy_alm2map_impl )
368+ _healpy_alm2map_p .def_abstract_eval (_healpy_alm2map_abstract_eval )
369+ ad .deflinear (_healpy_alm2map_p , _healpy_alm2map_transpose )
370+
371+
372+ def healpy_alm2map (flm : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
373+ """
374+ JAX wrapper for healpy.alm2map function (inverse spherical harmonic transform).
375+
376+ This wrapper assumes the passed spherical harmonic coefficients are a one
377+ dimensional array using HEALPix (ring-ordered) indexing. To instead pass a
378+ two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`.
379+
380+ Args:
381+ flm (jnp.ndarray): Spherical harmonic coefficients.
382+
383+ L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
384+
385+ nside (int): HEALPix Nside resolution parameter.
386+
387+ Returns:
388+ jnp.ndarray: Signal on the sphere.
389+
390+ """
391+ return _healpy_alm2map_p .bind (flm , L = L , nside = nside )
298392
299393
300- @custom_vjp
301394def healpy_forward (f : jnp .ndarray , L : int , nside : int , iter : int = 3 ) -> jnp .ndarray :
302395 r"""
303396 Compute the forward scalar spherical harmonic transform (healpy JAX).
304397
305398 HEALPix is a C++ library which implements the scalar spherical harmonic transform
306399 outlined in [1]. We make use of their healpy python bindings for which we provide
307- custom JAX frontends, hence providing support for automatic differentiation. Currently
308- these transforms can only be deployed on CPU, which is a limitation of the C++ library.
400+ custom JAX frontends, hence providing support for automatic differentiation.
401+ Currently these transforms can only be deployed on CPU, which is a limitation of the
402+ C++ library.
309403
310404 Args:
311405 f (jnp.ndarray): Signal on the sphere.
312406
313407 L (int): Harmonic band-limit.
314408
315- nside (int, optional): HEALPix Nside resolution parameter. Only required
316- if sampling="healpix". Defaults to None.
409+ nside (int): HEALPix Nside resolution parameter.
317410
318- iter (int, optional): Number of subiterations for healpy. Note that iterations
319- increase the precision of the forward transform, but reduce the accuracy of
320- the gradient pass. Between 2 and 3 iterations is a good compromise.
411+ iter (int, optional): Number of subiterations (iterative refinement steps) for
412+ healpy. Note that iterations increase the precision of the forward transform
413+ as an inverse of inverse transform, but with a linear increase in
414+ computational cost. Between 2 and 3 iterations is a good compromise.
321415
322416 Returns:
323417 jnp.ndarray: Harmonic coefficients of signal f.
@@ -328,28 +422,39 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
328422 Astrophysical Journal 622.2 (2005): 759
329423
330424 """
331- flm = jnp .array (healpy .map2alm (np .array (f ), lmax = L - 1 , iter = iter ))
425+ flm = healpy_map2alm (f , L , nside )
426+ for _ in range (iter ):
427+ f_recov = healpy_alm2map (flm , L , nside )
428+ f_error = f - f_recov
429+ flm += healpy_map2alm (f_error , L , nside )
332430 return reindex .flm_hp_to_2d_fast (flm , L )
333431
334432
335- def _healpy_forward_fwd (f : jnp .ndarray , L : int , nside : int , iter : int = 3 ):
336- """Private function which implements the forward pass for forward jax_healpy."""
337- res = ([], L , nside , iter )
338- return healpy_forward (f , L , nside , iter ), res
433+ def healpy_inverse (flm : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
434+ r"""
435+ Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
339436
437+ HEALPix is a C++ library which implements the scalar spherical harmonic transform
438+ outlined in [1]. We make use of their healpy python bindings for which we provide
439+ custom JAX frontends, hence providing support for automatic differentiation.
440+ Currently these transforms can only be deployed on CPU, which is a limitation of the
441+ C++ library.
340442
341- def _healpy_forward_bwd (res , flm ):
342- """Private function which implements the backward pass for forward jax_healpy."""
343- _ , L , nside , _ = res
344- flm_new = reindex .flm_2d_to_hp_fast (flm , L )
345- f = jnp .array (
346- np .conj (healpy .alm2map (np .conj (np .array (flm_new )), lmax = L - 1 , nside = nside ))
347- )
348- return f * (4 * jnp .pi ) / (12 * nside ** 2 ), None , None , None
443+ Args:
444+ flm (jnp.ndarray): Spherical harmonic coefficients.
349445
446+ L (int): Harmonic band-limit.
350447
351- # Link JAX gradients for C backend functions
352- ssht_inverse .defvjp (_ssht_inverse_fwd , _ssht_inverse_bwd )
353- ssht_forward .defvjp (_ssht_forward_fwd , _ssht_forward_bwd )
354- healpy_inverse .defvjp (_healpy_inverse_fwd , _healpy_inverse_bwd )
355- healpy_forward .defvjp (_healpy_forward_fwd , _healpy_forward_bwd )
448+ nside (int): HEALPix Nside resolution parameter.
449+
450+ Returns:
451+ jnp.ndarray: Signal on the sphere.
452+
453+ Note:
454+ [1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
455+ discretization and fast analysis of data distributed on the sphere." The
456+ Astrophysical Journal 622.2 (2005): 759
457+
458+ """
459+ flm = reindex .flm_2d_to_hp_fast (flm , L )
460+ return healpy_alm2map (flm , L , nside )
0 commit comments