Skip to content

Commit 6b883e0

Browse files
committed
Make pyssht and healpy optional dependencies
1 parent 876e090 commit 6b883e0

File tree

4 files changed

+46
-12
lines changed

4 files changed

+46
-12
lines changed

notebooks/JAX_HEALPix_frontend.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
"import sys\n",
2020
"IN_COLAB = 'google.colab' in sys.modules\n",
2121
"\n",
22-
"# Install s2fft and data if running on google colab.\n",
22+
"# Install s2fft and healpy if running on google colab.\n",
2323
"if IN_COLAB:\n",
24-
" !pip install s2fft &> /dev/null"
24+
" !pip install s2fft healpy &> /dev/null"
2525
]
2626
},
2727
{

notebooks/JAX_SSHT_frontend.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
"import sys\n",
2020
"IN_COLAB = 'google.colab' in sys.modules\n",
2121
"\n",
22-
"# Install s2fft and data if running on google colab.\n",
22+
"# Install s2fft and pyssht if running on google colab.\n",
2323
"if IN_COLAB:\n",
24-
" !pip install s2fft &> /dev/null"
24+
" !pip install s2fft pyssht &> /dev/null"
2525
]
2626
},
2727
{

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,12 @@ classifiers = [
2727
]
2828
description = "Differentiable and accelerated spherical transforms with JAX"
2929
dependencies = [
30-
"numpy>=1.20,<2",
30+
"numpy>=1.20",
3131
"colorlog",
3232
"pyyaml",
3333
"jax>=0.3.13",
3434
"jaxlib",
3535
"torch",
36-
"pyssht",
37-
"healpy",
38-
"ducc0",
3936
]
4037
dynamic = [
4138
"version",
@@ -71,9 +68,13 @@ plotting = [
7168
"ipywidgets",
7269
]
7370
tests = [
71+
"ducc0",
72+
"healpy",
73+
"numpy<2", # Required currently due to lack of Numpy v2 compatible pyssht release
7474
"pytest",
7575
"pytest-cov",
7676
"so3",
77+
"pyssht",
7778
]
7879

7980
[tool.scikit-build]

s2fft/transforms/c_backend_spherical.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,33 @@
1+
"""C backend functions for which to provide JAX frontends."""
2+
3+
import importlib
14
from functools import partial
5+
from types import ModuleType
26

3-
import healpy
47
import jax.numpy as jnp
58
import numpy as np
6-
7-
# C backend functions for which to provide JAX frontend.
8-
import pyssht
99
from jax import core, custom_vjp
1010
from jax.interpreters import ad
1111

1212
from s2fft.sampling import reindex
1313
from s2fft.utils import iterative_refinement, quadrature_jax
1414

1515

16+
class MissingWrapperDependencyError(Exception):
17+
"""Exception raised when a dependency for a wrapper function is missing."""
18+
19+
20+
def _try_import_module(module_name: str) -> ModuleType:
21+
"""Try to import a named module which may not be installed."""
22+
try:
23+
module = importlib.import_module(module_name)
24+
except ImportError as e:
25+
raise MissingWrapperDependencyError(
26+
"Wrapper function requires {module_name} to be installed"
27+
) from e
28+
return module
29+
30+
1631
@custom_vjp
1732
def ssht_inverse(
1833
flm: jnp.ndarray,
@@ -30,6 +45,8 @@ def ssht_inverse(
3045
custom JAX frontends, hence providing support for automatic differentiation. Currently
3146
these transforms can only be deployed on CPU, which is a limitation of the SSHT C package.
3247
48+
Requires `pyssht` package to be installed.
49+
3350
Args:
3451
flm (jnp.ndarray): Spherical harmonic coefficients.
3552
@@ -56,6 +73,7 @@ def ssht_inverse(
5673
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
5774
5875
"""
76+
pyssht = _try_import_module("pyssht")
5977
sampling_str = ["MW", "MWSS", "DH", "GL"]
6078
flm_1d = reindex.flm_2d_to_1d_fast(flm, L)
6179
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
@@ -86,6 +104,7 @@ def _ssht_inverse_fwd(
86104

87105
def _ssht_inverse_bwd(res, f):
88106
"""Private function which implements the backward pass for inverse jax_ssht."""
107+
pyssht = _try_import_module("pyssht")
89108
_, L, spin, reality, ssht_sampling, _ssht_backend = res
90109
sampling_str = ["MW", "MWSS", "DH", "GL"]
91110
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
@@ -149,6 +168,8 @@ def ssht_forward(
149168
custom JAX frontends, hence providing support for automatic differentiation. Currently
150169
these transforms can only be deployed on CPU, which is a limitation of the SSHT C package.
151170
171+
Requires `pyssht` package to be installed.
172+
152173
Args:
153174
f (jnp.ndarray): Signal on the sphere.
154175
@@ -175,6 +196,7 @@ def ssht_forward(
175196
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
176197
177198
"""
199+
pyssht = _try_import_module("pyssht")
178200
sampling_str = ["MW", "MWSS", "DH", "GL"]
179201
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
180202
flm = jnp.array(
@@ -205,6 +227,7 @@ def _ssht_forward_fwd(
205227

206228
def _ssht_forward_bwd(res, flm):
207229
"""Private function which implements the backward pass for forward jax_ssht."""
230+
pyssht = _try_import_module("pyssht")
208231
_, L, spin, reality, ssht_sampling, _ssht_backend = res
209232
sampling_str = ["MW", "MWSS", "DH", "GL"]
210233
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
@@ -300,6 +323,7 @@ def _real_dtype(complex_dtype):
300323

301324

302325
def _healpy_map2alm_impl(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
326+
healpy = _try_import_module("healpy")
303327
return jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=0))
304328

305329

@@ -332,6 +356,8 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
332356
array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional
333357
array of harmonic coefficients use :py:func:`healpy_forward`.
334358
359+
Requires `healpy` package to be installed.
360+
335361
Args:
336362
f (jnp.ndarray): Signal on the sphere.
337363
@@ -347,6 +373,7 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
347373

348374

349375
def _healpy_alm2map_impl(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
376+
healpy = _try_import_module("healpy")
350377
return jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))
351378

352379

@@ -384,6 +411,8 @@ def healpy_alm2map(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
384411
dimensional array using HEALPix (ring-ordered) indexing. To instead pass a
385412
two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`.
386413
414+
Requires `healpy` package to be installed.
415+
387416
Args:
388417
flm (jnp.ndarray): Spherical harmonic coefficients.
389418
@@ -408,6 +437,8 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
408437
Currently these transforms can only be deployed on CPU, which is a limitation of the
409438
C++ library.
410439
440+
Requires `healpy` package to be installed.
441+
411442
Args:
412443
f (jnp.ndarray): Signal on the sphere.
413444
@@ -448,6 +479,8 @@ def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
448479
Currently these transforms can only be deployed on CPU, which is a limitation of the
449480
C++ library.
450481
482+
Requires `healpy` package to be installed.
483+
451484
Args:
452485
flm (jnp.ndarray): Spherical harmonic coefficients.
453486

0 commit comments

Comments
 (0)