Skip to content

Commit c5464ec

Browse files
committed
Refactor healpix forward / inverse transforms as custom primitives
1 parent 909e6f1 commit c5464ec

File tree

1 file changed

+171
-66
lines changed

1 file changed

+171
-66
lines changed

s2fft/transforms/c_backend_spherical.py

Lines changed: 171 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
# C backend functions for which to provide JAX frontend.
66
import pyssht
7-
from jax import custom_vjp
7+
from jax import core, custom_vjp
8+
from jax.interpreters import ad
89

910
from s2fft.sampling import reindex
1011
from s2fft.utils import quadrature_jax
@@ -241,83 +242,176 @@ def _ssht_forward_bwd(res, flm):
241242
return f, None, None, None, None, None
242243

243244

244-
@custom_vjp
245-
def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
246-
r"""
247-
Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
245+
# Link JAX gradients for C backend functions
246+
ssht_inverse.defvjp(_ssht_inverse_fwd, _ssht_inverse_bwd)
247+
ssht_forward.defvjp(_ssht_forward_fwd, _ssht_forward_bwd)
248248

249-
HEALPix is a C++ library which implements the scalar spherical harmonic transform
250-
outlined in [1]. We make use of their healpy python bindings for which we provide
251-
custom JAX frontends, hence providing support for automatic differentiation. Currently
252-
these transforms can only be deployed on CPU, which is a limitation of the C++ library.
253249

254-
Args:
255-
flm (jnp.ndarray): Spherical harmonic coefficients.
250+
def _complex_dtype(real_dtype):
251+
"""
252+
Get complex datatype corresponding to a given real datatype.
256253
257-
L (int): Harmonic band-limit.
254+
Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L92
258255
259-
nside (int, optional): HEALPix Nside resolution parameter. Only required
260-
if sampling="healpix". Defaults to None.
256+
Original license:
261257
262-
Returns:
263-
jnp.ndarray: Signal on the sphere.
258+
Copyright 2019 The JAX Authors.
264259
265-
Note:
266-
[1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
267-
discretization and fast analysis of data distributed on the sphere." The
268-
Astrophysical Journal 622.2 (2005): 759
260+
Licensed under the Apache License, Version 2.0 (the "License");
261+
you may not use this file except in compliance with the License.
262+
You may obtain a copy of the License at
263+
264+
https://www.apache.org/licenses/LICENSE-2.0
269265
266+
Unless required by applicable law or agreed to in writing, software
267+
distributed under the License is distributed on an "AS IS" BASIS,
268+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
269+
See the License for the specific language governing permissions and
270+
limitations under the License.
270271
"""
271-
flm = reindex.flm_2d_to_hp_fast(flm, L)
272-
f = jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))
273-
return f
272+
return (np.zeros((), real_dtype) + np.zeros((), np.complex64)).dtype
273+
274+
275+
def _real_dtype(complex_dtype):
276+
"""
277+
Get real datatype corresponding to a given complex datatype.
278+
279+
Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L93
280+
281+
Original license:
282+
283+
Copyright 2019 The JAX Authors.
284+
285+
Licensed under the Apache License, Version 2.0 (the "License");
286+
you may not use this file except in compliance with the License.
287+
You may obtain a copy of the License at
288+
289+
https://www.apache.org/licenses/LICENSE-2.0
274290
291+
Unless required by applicable law or agreed to in writing, software
292+
distributed under the License is distributed on an "AS IS" BASIS,
293+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
294+
See the License for the specific language governing permissions and
295+
limitations under the License.
296+
"""
297+
return np.finfo(complex_dtype).dtype
298+
299+
300+
def _healpy_map2alm_impl(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
301+
return jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=0))
275302

276-
def _healpy_inverse_fwd(flm: jnp.ndarray, L: int, nside: int):
277-
"""Private function which implements the forward pass for inverse jax_healpy."""
278-
res = ([], L, nside)
279-
return healpy_inverse(flm, L, nside), res
280303

304+
def _healpy_map2alm_abstract_eval(
305+
f: core.ShapedArray, L: int, nside: int
306+
) -> core.ShapedArray:
307+
return core.ShapedArray(shape=(L * (L + 1) // 2,), dtype=_complex_dtype(f.dtype))
281308

282-
def _healpy_inverse_bwd(res, f):
283-
"""Private function which implements the backward pass for inverse jax_healpy."""
284-
_, L, nside = res
285-
f_new = f * (12 * nside**2) / (4 * jnp.pi)
286-
flm_out = jnp.array(
287-
np.conj(healpy.map2alm(np.conj(np.array(f_new)), lmax=L - 1, iter=0))
309+
310+
def _healpy_map2alm_transpose(dflm: jnp.ndarray, L: int, nside: int):
311+
scale_factors = (
312+
jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
313+
* (3 * nside**2)
314+
/ jnp.pi
288315
)
289-
# iter MUST be zero otherwise gradient propagation is incorrect (JDM).
290-
flm_out = reindex.flm_hp_to_2d_fast(flm_out, L)
291-
m_conj = (-1) ** (jnp.arange(1, L) % 2)
292-
flm_out = flm_out.at[..., L:].add(
293-
jnp.flip(m_conj * jnp.conj(flm_out[..., : L - 1]), axis=-1)
316+
return (jnp.conj(healpy_alm2map(jnp.conj(dflm) / scale_factors, L, nside)),)
317+
318+
319+
_healpy_map2alm_p = core.Primitive("healpy_map2alm")
320+
_healpy_map2alm_p.def_impl(_healpy_map2alm_impl)
321+
_healpy_map2alm_p.def_abstract_eval(_healpy_map2alm_abstract_eval)
322+
ad.deflinear(_healpy_map2alm_p, _healpy_map2alm_transpose)
323+
324+
325+
def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
326+
"""
327+
JAX wrapper for healpy.map2alm function (forward spherical harmonic transform).
328+
329+
This wrapper will return the spherical harmonic coefficients as a one dimensional
330+
array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional
331+
array of harmonic coefficients use :py:func:`healpy_forward`.
332+
333+
Args:
334+
f (jnp.ndarray): Signal on the sphere.
335+
336+
L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
337+
338+
nside (int): HEALPix Nside resolution parameter.
339+
340+
Returns:
341+
jnp.ndarray: Harmonic coefficients of signal f.
342+
343+
"""
344+
return _healpy_map2alm_p.bind(f, L=L, nside=nside)
345+
346+
347+
def _healpy_alm2map_impl(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
348+
return jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))
349+
350+
351+
def _healpy_alm2map_abstract_eval(
352+
flm: core.ShapedArray, L: int, nside: int
353+
) -> core.ShapedArray:
354+
return core.ShapedArray(shape=(12 * nside**2,), dtype=_real_dtype(flm.dtype))
355+
356+
357+
def _healpy_alm2map_transpose(df: jnp.ndarray, L: int, nside: int) -> tuple:
358+
scale_factors = (
359+
jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
360+
* (3 * nside**2)
361+
/ jnp.pi
294362
)
295-
flm_out = flm_out.at[..., : L - 1].set(0)
363+
return (scale_factors * jnp.conj(healpy_map2alm(jnp.conj(df), L, nside)),)
296364

297-
return flm_out, None, None
365+
366+
_healpy_alm2map_p = core.Primitive("healpy_alm2map")
367+
_healpy_alm2map_p.def_impl(_healpy_alm2map_impl)
368+
_healpy_alm2map_p.def_abstract_eval(_healpy_alm2map_abstract_eval)
369+
ad.deflinear(_healpy_alm2map_p, _healpy_alm2map_transpose)
370+
371+
372+
def healpy_alm2map(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
373+
"""
374+
JAX wrapper for healpy.alm2map function (inverse spherical harmonic transform).
375+
376+
This wrapper assumes the passed spherical harmonic coefficients are a one
377+
dimensional array using HEALPix (ring-ordered) indexing. To instead pass a
378+
two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`.
379+
380+
Args:
381+
flm (jnp.ndarray): Spherical harmonic coefficients.
382+
383+
L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
384+
385+
nside (int): HEALPix Nside resolution parameter.
386+
387+
Returns:
388+
jnp.ndarray: Signal on the sphere.
389+
390+
"""
391+
return _healpy_alm2map_p.bind(flm, L=L, nside=nside)
298392

299393

300-
@custom_vjp
301394
def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.ndarray:
302395
r"""
303396
Compute the forward scalar spherical harmonic transform (healpy JAX).
304397
305398
HEALPix is a C++ library which implements the scalar spherical harmonic transform
306399
outlined in [1]. We make use of their healpy python bindings for which we provide
307-
custom JAX frontends, hence providing support for automatic differentiation. Currently
308-
these transforms can only be deployed on CPU, which is a limitation of the C++ library.
400+
custom JAX frontends, hence providing support for automatic differentiation.
401+
Currently these transforms can only be deployed on CPU, which is a limitation of the
402+
C++ library.
309403
310404
Args:
311405
f (jnp.ndarray): Signal on the sphere.
312406
313407
L (int): Harmonic band-limit.
314408
315-
nside (int, optional): HEALPix Nside resolution parameter. Only required
316-
if sampling="healpix". Defaults to None.
409+
nside (int): HEALPix Nside resolution parameter.
317410
318-
iter (int, optional): Number of subiterations for healpy. Note that iterations
319-
increase the precision of the forward transform, but reduce the accuracy of
320-
the gradient pass. Between 2 and 3 iterations is a good compromise.
411+
iter (int, optional): Number of subiterations (iterative refinement steps) for
412+
healpy. Note that iterations increase the precision of the forward transform
413+
as an inverse of inverse transform, but with a linear increase in
414+
computational cost. Between 2 and 3 iterations is a good compromise.
321415
322416
Returns:
323417
jnp.ndarray: Harmonic coefficients of signal f.
@@ -328,28 +422,39 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
328422
Astrophysical Journal 622.2 (2005): 759
329423
330424
"""
331-
flm = jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=iter))
425+
flm = healpy_map2alm(f, L, nside)
426+
for _ in range(iter):
427+
f_recov = healpy_alm2map(flm, L, nside)
428+
f_error = f - f_recov
429+
flm += healpy_map2alm(f_error, L, nside)
332430
return reindex.flm_hp_to_2d_fast(flm, L)
333431

334432

335-
def _healpy_forward_fwd(f: jnp.ndarray, L: int, nside: int, iter: int = 3):
336-
"""Private function which implements the forward pass for forward jax_healpy."""
337-
res = ([], L, nside, iter)
338-
return healpy_forward(f, L, nside, iter), res
433+
def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
434+
r"""
435+
Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
339436
437+
HEALPix is a C++ library which implements the scalar spherical harmonic transform
438+
outlined in [1]. We make use of their healpy python bindings for which we provide
439+
custom JAX frontends, hence providing support for automatic differentiation.
440+
Currently these transforms can only be deployed on CPU, which is a limitation of the
441+
C++ library.
340442
341-
def _healpy_forward_bwd(res, flm):
342-
"""Private function which implements the backward pass for forward jax_healpy."""
343-
_, L, nside, _ = res
344-
flm_new = reindex.flm_2d_to_hp_fast(flm, L)
345-
f = jnp.array(
346-
np.conj(healpy.alm2map(np.conj(np.array(flm_new)), lmax=L - 1, nside=nside))
347-
)
348-
return f * (4 * jnp.pi) / (12 * nside**2), None, None, None
443+
Args:
444+
flm (jnp.ndarray): Spherical harmonic coefficients.
349445
446+
L (int): Harmonic band-limit.
350447
351-
# Link JAX gradients for C backend functions
352-
ssht_inverse.defvjp(_ssht_inverse_fwd, _ssht_inverse_bwd)
353-
ssht_forward.defvjp(_ssht_forward_fwd, _ssht_forward_bwd)
354-
healpy_inverse.defvjp(_healpy_inverse_fwd, _healpy_inverse_bwd)
355-
healpy_forward.defvjp(_healpy_forward_fwd, _healpy_forward_bwd)
448+
nside (int): HEALPix Nside resolution parameter.
449+
450+
Returns:
451+
jnp.ndarray: Signal on the sphere.
452+
453+
Note:
454+
[1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
455+
discretization and fast analysis of data distributed on the sphere." The
456+
Astrophysical Journal 622.2 (2005): 759
457+
458+
"""
459+
flm = reindex.flm_2d_to_hp_fast(flm, L)
460+
return healpy_alm2map(flm, L, nside)

0 commit comments

Comments
 (0)