Skip to content

Commit dc9b2bc

Browse files
authored
Merge pull request #244 from astro-informatics/mmg/healpix-gradient-fix
Correct `healpix_forward` derivatives and add support for forward and higher order autodiff
2 parents 5210481 + fe82eaa commit dc9b2bc

File tree

2 files changed

+180
-80
lines changed

2 files changed

+180
-80
lines changed

s2fft/transforms/c_backend_spherical.py

Lines changed: 176 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
# C backend functions for which to provide JAX frontend.
66
import pyssht
7-
from jax import custom_vjp
7+
from jax import core, custom_vjp
8+
from jax.interpreters import ad
89

910
from s2fft.sampling import reindex
1011
from s2fft.utils import quadrature_jax
@@ -241,83 +242,181 @@ def _ssht_forward_bwd(res, flm):
241242
return f, None, None, None, None, None
242243

243244

244-
@custom_vjp
245-
def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
246-
r"""
247-
Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
245+
# Link JAX gradients for C backend functions
246+
ssht_inverse.defvjp(_ssht_inverse_fwd, _ssht_inverse_bwd)
247+
ssht_forward.defvjp(_ssht_forward_fwd, _ssht_forward_bwd)
248248

249-
HEALPix is a C++ library which implements the scalar spherical harmonic transform
250-
outlined in [1]. We make use of their healpy python bindings for which we provide
251-
custom JAX frontends, hence providing support for automatic differentiation. Currently
252-
these transforms can only be deployed on CPU, which is a limitation of the C++ library.
253249

254-
Args:
255-
flm (jnp.ndarray): Spherical harmonic coefficients.
250+
def _complex_dtype(real_dtype):
251+
"""
252+
Get complex datatype corresponding to a given real datatype.
256253
257-
L (int): Harmonic band-limit.
254+
Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L92
258255
259-
nside (int, optional): HEALPix Nside resolution parameter. Only required
260-
if sampling="healpix". Defaults to None.
256+
Original license:
261257
262-
Returns:
263-
jnp.ndarray: Signal on the sphere.
258+
Copyright 2019 The JAX Authors.
264259
265-
Note:
266-
[1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
267-
discretization and fast analysis of data distributed on the sphere." The
268-
Astrophysical Journal 622.2 (2005): 759
260+
Licensed under the Apache License, Version 2.0 (the "License");
261+
you may not use this file except in compliance with the License.
262+
You may obtain a copy of the License at
263+
264+
https://www.apache.org/licenses/LICENSE-2.0
269265
266+
Unless required by applicable law or agreed to in writing, software
267+
distributed under the License is distributed on an "AS IS" BASIS,
268+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
269+
See the License for the specific language governing permissions and
270+
limitations under the License.
270271
"""
271-
flm = reindex.flm_2d_to_hp_fast(flm, L)
272-
f = jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))
273-
return f
272+
return (np.zeros((), real_dtype) + np.zeros((), np.complex64)).dtype
273+
274+
275+
def _real_dtype(complex_dtype):
276+
"""
277+
Get real datatype corresponding to a given complex datatype.
278+
279+
Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L93
280+
281+
Original license:
282+
283+
Copyright 2019 The JAX Authors.
284+
285+
Licensed under the Apache License, Version 2.0 (the "License");
286+
you may not use this file except in compliance with the License.
287+
You may obtain a copy of the License at
288+
289+
https://www.apache.org/licenses/LICENSE-2.0
274290
291+
Unless required by applicable law or agreed to in writing, software
292+
distributed under the License is distributed on an "AS IS" BASIS,
293+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
294+
See the License for the specific language governing permissions and
295+
limitations under the License.
296+
"""
297+
return np.finfo(complex_dtype).dtype
298+
299+
300+
def _healpy_map2alm_impl(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
301+
return jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=0))
275302

276-
def _healpy_inverse_fwd(flm: jnp.ndarray, L: int, nside: int):
277-
"""Private function which implements the forward pass for inverse jax_healpy."""
278-
res = ([], L, nside)
279-
return healpy_inverse(flm, L, nside), res
280303

304+
def _healpy_map2alm_abstract_eval(
305+
f: core.ShapedArray, L: int, nside: int
306+
) -> core.ShapedArray:
307+
return core.ShapedArray(shape=(L * (L + 1) // 2,), dtype=_complex_dtype(f.dtype))
281308

282-
def _healpy_inverse_bwd(res, f):
283-
"""Private function which implements the backward pass for inverse jax_healpy."""
284-
_, L, nside = res
285-
f_new = f * (12 * nside**2) / (4 * jnp.pi)
286-
flm_out = jnp.array(
287-
np.conj(healpy.map2alm(np.conj(np.array(f_new)), lmax=L - 1, iter=0))
309+
310+
def _healpy_map2alm_transpose(dflm: jnp.ndarray, L: int, nside: int):
311+
scale_factors = (
312+
jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
313+
* (3 * nside**2)
314+
/ jnp.pi
288315
)
289-
# iter MUST be zero otherwise gradient propagation is incorrect (JDM).
290-
flm_out = reindex.flm_hp_to_2d_fast(flm_out, L)
291-
m_conj = (-1) ** (jnp.arange(1, L) % 2)
292-
flm_out = flm_out.at[..., L:].add(
293-
jnp.flip(m_conj * jnp.conj(flm_out[..., : L - 1]), axis=-1)
316+
return (jnp.conj(healpy_alm2map(jnp.conj(dflm) / scale_factors, L, nside)),)
317+
318+
319+
_healpy_map2alm_p = core.Primitive("healpy_map2alm")
320+
_healpy_map2alm_p.def_impl(_healpy_map2alm_impl)
321+
_healpy_map2alm_p.def_abstract_eval(_healpy_map2alm_abstract_eval)
322+
ad.deflinear(_healpy_map2alm_p, _healpy_map2alm_transpose)
323+
324+
325+
def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
326+
"""
327+
JAX wrapper for healpy map2alm function (forward spherical harmonic transform).
328+
329+
This wrapper will return the spherical harmonic coefficients as a one dimensional
330+
array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional
331+
array of harmonic coefficients use :py:func:`healpy_forward`.
332+
333+
Args:
334+
f (jnp.ndarray): Signal on the sphere.
335+
336+
L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
337+
338+
nside (int): HEALPix Nside resolution parameter.
339+
340+
Returns:
341+
jnp.ndarray: Harmonic coefficients of signal f.
342+
343+
"""
344+
return _healpy_map2alm_p.bind(f, L=L, nside=nside)
345+
346+
347+
def _healpy_alm2map_impl(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
348+
return jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))
349+
350+
351+
def _healpy_alm2map_abstract_eval(
352+
flm: core.ShapedArray, L: int, nside: int
353+
) -> core.ShapedArray:
354+
return core.ShapedArray(shape=(12 * nside**2,), dtype=_real_dtype(flm.dtype))
355+
356+
357+
def _healpy_alm2map_transpose(df: jnp.ndarray, L: int, nside: int) -> tuple:
358+
scale_factors = (
359+
jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
360+
* (3 * nside**2)
361+
/ jnp.pi
294362
)
295-
flm_out = flm_out.at[..., : L - 1].set(0)
363+
# Scale factor above includes the inverse quadrature weight given by
364+
# (12 * nside**2) / (4 * jnp.pi) = (3 * nside**2) / jnp.pi
365+
# and also a factor of 2 for m>0 to account for the negative m.
366+
# See explanation in this issue comment:
367+
# https://github.com/astro-informatics/s2fft/issues/243#issuecomment-2500951488
368+
return (scale_factors * jnp.conj(healpy_map2alm(jnp.conj(df), L, nside)),)
296369

297-
return flm_out, None, None
370+
371+
_healpy_alm2map_p = core.Primitive("healpy_alm2map")
372+
_healpy_alm2map_p.def_impl(_healpy_alm2map_impl)
373+
_healpy_alm2map_p.def_abstract_eval(_healpy_alm2map_abstract_eval)
374+
ad.deflinear(_healpy_alm2map_p, _healpy_alm2map_transpose)
375+
376+
377+
def healpy_alm2map(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
378+
"""
379+
JAX wrapper for healpy alm2map function (inverse spherical harmonic transform).
380+
381+
This wrapper assumes the passed spherical harmonic coefficients are a one
382+
dimensional array using HEALPix (ring-ordered) indexing. To instead pass a
383+
two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`.
384+
385+
Args:
386+
flm (jnp.ndarray): Spherical harmonic coefficients.
387+
388+
L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy.
389+
390+
nside (int): HEALPix Nside resolution parameter.
391+
392+
Returns:
393+
jnp.ndarray: Signal on the sphere.
394+
395+
"""
396+
return _healpy_alm2map_p.bind(flm, L=L, nside=nside)
298397

299398

300-
@custom_vjp
301399
def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.ndarray:
302400
r"""
303401
Compute the forward scalar spherical harmonic transform (healpy JAX).
304402
305403
HEALPix is a C++ library which implements the scalar spherical harmonic transform
306404
outlined in [1]. We make use of their healpy python bindings for which we provide
307-
custom JAX frontends, hence providing support for automatic differentiation. Currently
308-
these transforms can only be deployed on CPU, which is a limitation of the C++ library.
405+
custom JAX frontends, hence providing support for automatic differentiation.
406+
Currently these transforms can only be deployed on CPU, which is a limitation of the
407+
C++ library.
309408
310409
Args:
311410
f (jnp.ndarray): Signal on the sphere.
312411
313412
L (int): Harmonic band-limit.
314413
315-
nside (int, optional): HEALPix Nside resolution parameter. Only required
316-
if sampling="healpix". Defaults to None.
414+
nside (int): HEALPix Nside resolution parameter.
317415
318-
iter (int, optional): Number of subiterations for healpy. Note that iterations
319-
increase the precision of the forward transform, but reduce the accuracy of
320-
the gradient pass. Between 2 and 3 iterations is a good compromise.
416+
iter (int, optional): Number of subiterations (iterative refinement steps) for
417+
healpy. Note that iterations increase the precision of the forward transform
418+
as an inverse of inverse transform, but with a linear increase in
419+
computational cost. Between 2 and 3 iterations is a good compromise.
321420
322421
Returns:
323422
jnp.ndarray: Harmonic coefficients of signal f.
@@ -328,28 +427,39 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
328427
Astrophysical Journal 622.2 (2005): 759
329428
330429
"""
331-
flm = jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=iter))
430+
flm = healpy_map2alm(f, L, nside)
431+
for _ in range(iter):
432+
f_recov = healpy_alm2map(flm, L, nside)
433+
f_error = f - f_recov
434+
flm += healpy_map2alm(f_error, L, nside)
332435
return reindex.flm_hp_to_2d_fast(flm, L)
333436

334437

335-
def _healpy_forward_fwd(f: jnp.ndarray, L: int, nside: int, iter: int = 3):
336-
"""Private function which implements the forward pass for forward jax_healpy."""
337-
res = ([], L, nside, iter)
338-
return healpy_forward(f, L, nside, iter), res
438+
def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
439+
r"""
440+
Compute the inverse scalar real spherical harmonic transform (HEALPix JAX).
339441
442+
HEALPix is a C++ library which implements the scalar spherical harmonic transform
443+
outlined in [1]. We make use of their healpy python bindings for which we provide
444+
custom JAX frontends, hence providing support for automatic differentiation.
445+
Currently these transforms can only be deployed on CPU, which is a limitation of the
446+
C++ library.
340447
341-
def _healpy_forward_bwd(res, flm):
342-
"""Private function which implements the backward pass for forward jax_healpy."""
343-
_, L, nside, _ = res
344-
flm_new = reindex.flm_2d_to_hp_fast(flm, L)
345-
f = jnp.array(
346-
np.conj(healpy.alm2map(np.conj(np.array(flm_new)), lmax=L - 1, nside=nside))
347-
)
348-
return f * (4 * jnp.pi) / (12 * nside**2), None, None, None
448+
Args:
449+
flm (jnp.ndarray): Spherical harmonic coefficients.
349450
451+
L (int): Harmonic band-limit.
350452
351-
# Link JAX gradients for C backend functions
352-
ssht_inverse.defvjp(_ssht_inverse_fwd, _ssht_inverse_bwd)
353-
ssht_forward.defvjp(_ssht_forward_fwd, _ssht_forward_bwd)
354-
healpy_inverse.defvjp(_healpy_inverse_fwd, _healpy_inverse_bwd)
355-
healpy_forward.defvjp(_healpy_forward_fwd, _healpy_forward_bwd)
453+
nside (int): HEALPix Nside resolution parameter.
454+
455+
Returns:
456+
jnp.ndarray: Signal on the sphere.
457+
458+
Note:
459+
[1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution
460+
discretization and fast analysis of data distributed on the sphere." The
461+
Astrophysical Journal 622.2 (2005): 759
462+
463+
"""
464+
flm = reindex.flm_2d_to_hp_fast(flm, L)
465+
return healpy_alm2map(flm, L, nside)

tests/test_spherical_custom_grads.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -307,22 +307,16 @@ def func(f):
307307
@pytest.mark.parametrize("nside", nside_to_test)
308308
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
309309
def test_healpix_c_backend_inverse_custom_gradients(flm_generator, nside: int):
310-
sampling = "healpix"
311310
L = 2 * nside
312311
reality = True
313312
flm = flm_generator(L=L, reality=reality)
314-
flm_target = flm_generator(L=L, reality=reality)
315-
f_target = spherical.inverse_jax(
316-
flm_target, L, nside=nside, sampling=sampling, reality=reality
317-
)
318313

319314
def func(flm):
320-
f = spherical.inverse(
315+
return spherical.inverse(
321316
flm, L, 0, nside, sampling="healpix", method="jax_healpy", reality=True
322317
)
323-
return jnp.sum(jnp.abs(f - f_target) ** 2)
324318

325-
check_grads(func, (flm,), order=1, modes=("rev"))
319+
check_grads(func, (flm,), order=2, modes=("fwd", "rev"))
326320

327321

328322
@pytest.mark.parametrize("nside", nside_to_test)
@@ -334,16 +328,12 @@ def test_healpix_c_backend_forward_custom_gradients(
334328
sampling = "healpix"
335329
L = 2 * nside
336330
reality = True
337-
flm_target = flm_generator(L=L, reality=reality)
338331
flm = flm_generator(L=L, reality=reality)
339332
f = spherical.inverse_jax(flm, L, nside=nside, sampling=sampling, reality=reality)
340333

341334
def func(f):
342-
flm = spherical.forward(
335+
return spherical.forward(
343336
f, L, nside=nside, sampling="healpix", method="jax_healpy", iter=iter
344337
)
345-
return jnp.sum(jnp.abs(flm - flm_target) ** 2)
346-
347-
rtol = [1e-6, 1e-2, 5e-2, 1e-2][iter]
348338

349-
check_grads(func, (f,), order=1, modes=("rev"), rtol=rtol)
339+
check_grads(func, (f,), order=2, modes=("fwd", "rev"))

0 commit comments

Comments
 (0)