1+ """C backend functions for which to provide JAX frontends."""
2+
3+ import importlib
14from functools import partial
5+ from types import ModuleType
26
3- import healpy
47import jax .numpy as jnp
58import numpy as np
6-
7- # C backend functions for which to provide JAX frontend.
8- import pyssht
99from jax import core , custom_vjp
1010from jax .interpreters import ad
1111
1212from s2fft .sampling import reindex
1313from s2fft .utils import iterative_refinement , quadrature_jax
1414
1515
16+ class MissingWrapperDependencyError (Exception ):
17+ """Exception raised when a dependency for a wrapper function is missing."""
18+
19+
20+ def _try_import_module (module_name : str ) -> ModuleType :
21+ """Try to import a named module which may not be installed."""
22+ try :
23+ module = importlib .import_module (module_name )
24+ except ImportError as e :
25+ raise MissingWrapperDependencyError (
26+ "Wrapper function requires {module_name} to be installed"
27+ ) from e
28+ return module
29+
30+
1631@custom_vjp
1732def ssht_inverse (
1833 flm : jnp .ndarray ,
@@ -30,6 +45,8 @@ def ssht_inverse(
3045 custom JAX frontends, hence providing support for automatic differentiation. Currently
3146 these transforms can only be deployed on CPU, which is a limitation of the SSHT C package.
3247
48+ Requires `pyssht` package to be installed.
49+
3350 Args:
3451 flm (jnp.ndarray): Spherical harmonic coefficients.
3552
@@ -56,6 +73,7 @@ def ssht_inverse(
5673 IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
5774
5875 """
76+ pyssht = _try_import_module ("pyssht" )
5977 sampling_str = ["MW" , "MWSS" , "DH" , "GL" ]
6078 flm_1d = reindex .flm_2d_to_1d_fast (flm , L )
6179 _backend = "SSHT" if _ssht_backend == 0 else "ducc0"
@@ -86,6 +104,7 @@ def _ssht_inverse_fwd(
86104
87105def _ssht_inverse_bwd (res , f ):
88106 """Private function which implements the backward pass for inverse jax_ssht."""
107+ pyssht = _try_import_module ("pyssht" )
89108 _ , L , spin , reality , ssht_sampling , _ssht_backend = res
90109 sampling_str = ["MW" , "MWSS" , "DH" , "GL" ]
91110 _backend = "SSHT" if _ssht_backend == 0 else "ducc0"
@@ -149,6 +168,8 @@ def ssht_forward(
149168 custom JAX frontends, hence providing support for automatic differentiation. Currently
150169 these transforms can only be deployed on CPU, which is a limitation of the SSHT C package.
151170
171+ Requires `pyssht` package to be installed.
172+
152173 Args:
153174 f (jnp.ndarray): Signal on the sphere.
154175
@@ -175,6 +196,7 @@ def ssht_forward(
175196 IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
176197
177198 """
199+ pyssht = _try_import_module ("pyssht" )
178200 sampling_str = ["MW" , "MWSS" , "DH" , "GL" ]
179201 _backend = "SSHT" if _ssht_backend == 0 else "ducc0"
180202 flm = jnp .array (
@@ -205,6 +227,7 @@ def _ssht_forward_fwd(
205227
206228def _ssht_forward_bwd (res , flm ):
207229 """Private function which implements the backward pass for forward jax_ssht."""
230+ pyssht = _try_import_module ("pyssht" )
208231 _ , L , spin , reality , ssht_sampling , _ssht_backend = res
209232 sampling_str = ["MW" , "MWSS" , "DH" , "GL" ]
210233 _backend = "SSHT" if _ssht_backend == 0 else "ducc0"
@@ -300,6 +323,7 @@ def _real_dtype(complex_dtype):
300323
301324
302325def _healpy_map2alm_impl (f : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
326+ healpy = _try_import_module ("healpy" )
303327 return jnp .array (healpy .map2alm (np .array (f ), lmax = L - 1 , iter = 0 ))
304328
305329
@@ -332,6 +356,8 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
332356 array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional
333357 array of harmonic coefficients use :py:func:`healpy_forward`.
334358
359+ Requires `healpy` package to be installed.
360+
335361 Args:
336362 f (jnp.ndarray): Signal on the sphere.
337363
@@ -347,6 +373,7 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
347373
348374
349375def _healpy_alm2map_impl (flm : jnp .ndarray , L : int , nside : int ) -> jnp .ndarray :
376+ healpy = _try_import_module ("healpy" )
350377 return jnp .array (healpy .alm2map (np .array (flm ), lmax = L - 1 , nside = nside ))
351378
352379
@@ -384,6 +411,8 @@ def healpy_alm2map(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
384411 dimensional array using HEALPix (ring-ordered) indexing. To instead pass a
385412 two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`.
386413
414+ Requires `healpy` package to be installed.
415+
387416 Args:
388417 flm (jnp.ndarray): Spherical harmonic coefficients.
389418
@@ -408,6 +437,8 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
408437 Currently these transforms can only be deployed on CPU, which is a limitation of the
409438 C++ library.
410439
440+ Requires `healpy` package to be installed.
441+
411442 Args:
412443 f (jnp.ndarray): Signal on the sphere.
413444
@@ -448,6 +479,8 @@ def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
448479 Currently these transforms can only be deployed on CPU, which is a limitation of the
449480 C++ library.
450481
482+ Requires `healpy` package to be installed.
483+
451484 Args:
452485 flm (jnp.ndarray): Spherical harmonic coefficients.
453486
0 commit comments