Skip to content

Commit 601fb95

Browse files
authored
Make pyssht and healpy optional dependencies (#267)
* Make pyssht and healpy optional dependencies * Add test for try import function
1 parent f48a7fc commit 601fb95

File tree

5 files changed

+59
-12
lines changed

5 files changed

+59
-12
lines changed

notebooks/JAX_HEALPix_frontend.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
"import sys\n",
2020
"IN_COLAB = 'google.colab' in sys.modules\n",
2121
"\n",
22-
"# Install s2fft and data if running on google colab.\n",
22+
"# Install s2fft and healpy if running on google colab.\n",
2323
"if IN_COLAB:\n",
24-
" !pip install s2fft &> /dev/null"
24+
" !pip install s2fft healpy &> /dev/null"
2525
]
2626
},
2727
{

notebooks/JAX_SSHT_frontend.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
"import sys\n",
2020
"IN_COLAB = 'google.colab' in sys.modules\n",
2121
"\n",
22-
"# Install s2fft and data if running on google colab.\n",
22+
"# Install s2fft and pyssht if running on google colab.\n",
2323
"if IN_COLAB:\n",
24-
" !pip install s2fft &> /dev/null"
24+
" !pip install s2fft pyssht &> /dev/null"
2525
]
2626
},
2727
{

pyproject.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,12 @@ classifiers = [
2727
]
2828
description = "Differentiable and accelerated spherical transforms with JAX"
2929
dependencies = [
30-
"numpy>=1.20,<2",
30+
"numpy>=1.20",
3131
"colorlog",
3232
"pyyaml",
3333
"jax>=0.3.13",
3434
"jaxlib",
3535
"torch",
36-
"pyssht",
37-
"healpy",
38-
"ducc0",
3936
]
4037
dynamic = [
4138
"version",
@@ -71,9 +68,13 @@ plotting = [
7168
"ipywidgets",
7269
]
7370
tests = [
71+
"ducc0",
72+
"healpy",
73+
"numpy<2", # Required currently due to lack of Numpy v2 compatible pyssht release
7474
"pytest",
7575
"pytest-cov",
7676
"so3",
77+
"pyssht",
7778
]
7879

7980
[tool.scikit-build]

s2fft/transforms/c_backend_spherical.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,33 @@
1+
"""C backend functions for which to provide JAX frontends."""
2+
3+
import importlib
14
from functools import partial
5+
from types import ModuleType
26

3-
import healpy
47
import jax.numpy as jnp
58
import numpy as np
6-
7-
# C backend functions for which to provide JAX frontend.
8-
import pyssht
99
from jax import core, custom_vjp
1010
from jax.interpreters import ad
1111

1212
from s2fft.sampling import reindex
1313
from s2fft.utils import iterative_refinement, quadrature_jax
1414

1515

16+
class MissingWrapperDependencyError(Exception):
17+
"""Exception raised when a dependency for a wrapper function is missing."""
18+
19+
20+
def _try_import_module(module_name: str) -> ModuleType:
21+
"""Try to import a named module which may not be installed."""
22+
try:
23+
module = importlib.import_module(module_name)
24+
except ImportError as e:
25+
raise MissingWrapperDependencyError(
26+
"Wrapper function requires {module_name} to be installed"
27+
) from e
28+
return module
29+
30+
1631
@custom_vjp
1732
def ssht_inverse(
1833
flm: jnp.ndarray,
@@ -30,6 +45,8 @@ def ssht_inverse(
3045
custom JAX frontends, hence providing support for automatic differentiation. Currently
3146
these transforms can only be deployed on CPU, which is a limitation of the SSHT C package.
3247
48+
Requires `pyssht` package to be installed.
49+
3350
Args:
3451
flm (jnp.ndarray): Spherical harmonic coefficients.
3552
@@ -56,6 +73,7 @@ def ssht_inverse(
5673
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
5774
5875
"""
76+
pyssht = _try_import_module("pyssht")
5977
sampling_str = ["MW", "MWSS", "DH", "GL"]
6078
flm_1d = reindex.flm_2d_to_1d_fast(flm, L)
6179
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
@@ -86,6 +104,7 @@ def _ssht_inverse_fwd(
86104

87105
def _ssht_inverse_bwd(res, f):
88106
"""Private function which implements the backward pass for inverse jax_ssht."""
107+
pyssht = _try_import_module("pyssht")
89108
_, L, spin, reality, ssht_sampling, _ssht_backend = res
90109
sampling_str = ["MW", "MWSS", "DH", "GL"]
91110
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
@@ -149,6 +168,8 @@ def ssht_forward(
149168
custom JAX frontends, hence providing support for automatic differentiation. Currently
150169
these transforms can only be deployed on CPU, which is a limitation of the SSHT C package.
151170
171+
Requires `pyssht` package to be installed.
172+
152173
Args:
153174
f (jnp.ndarray): Signal on the sphere.
154175
@@ -175,6 +196,7 @@ def ssht_forward(
175196
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.
176197
177198
"""
199+
pyssht = _try_import_module("pyssht")
178200
sampling_str = ["MW", "MWSS", "DH", "GL"]
179201
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
180202
flm = jnp.array(
@@ -205,6 +227,7 @@ def _ssht_forward_fwd(
205227

206228
def _ssht_forward_bwd(res, flm):
207229
"""Private function which implements the backward pass for forward jax_ssht."""
230+
pyssht = _try_import_module("pyssht")
208231
_, L, spin, reality, ssht_sampling, _ssht_backend = res
209232
sampling_str = ["MW", "MWSS", "DH", "GL"]
210233
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
@@ -300,6 +323,7 @@ def _real_dtype(complex_dtype):
300323

301324

302325
def _healpy_map2alm_impl(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
326+
healpy = _try_import_module("healpy")
303327
return jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=0))
304328

305329

@@ -332,6 +356,8 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
332356
array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional
333357
array of harmonic coefficients use :py:func:`healpy_forward`.
334358
359+
Requires `healpy` package to be installed.
360+
335361
Args:
336362
f (jnp.ndarray): Signal on the sphere.
337363
@@ -347,6 +373,7 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
347373

348374

349375
def _healpy_alm2map_impl(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
376+
healpy = _try_import_module("healpy")
350377
return jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))
351378

352379

@@ -384,6 +411,8 @@ def healpy_alm2map(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
384411
dimensional array using HEALPix (ring-ordered) indexing. To instead pass a
385412
two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`.
386413
414+
Requires `healpy` package to be installed.
415+
387416
Args:
388417
flm (jnp.ndarray): Spherical harmonic coefficients.
389418
@@ -408,6 +437,8 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
408437
Currently these transforms can only be deployed on CPU, which is a limitation of the
409438
C++ library.
410439
440+
Requires `healpy` package to be installed.
441+
411442
Args:
412443
f (jnp.ndarray): Signal on the sphere.
413444
@@ -448,6 +479,8 @@ def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
448479
Currently these transforms can only be deployed on CPU, which is a limitation of the
449480
C++ library.
450481
482+
Requires `healpy` package to be installed.
483+
451484
Args:
452485
flm (jnp.ndarray): Spherical harmonic coefficients.
453486

tests/test_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from jax.test_util import check_grads
77

88
from s2fft.sampling import s2_samples as samples
9+
from s2fft.transforms.c_backend_spherical import (
10+
MissingWrapperDependencyError,
11+
_try_import_module,
12+
)
913
from s2fft.utils.rotation import generate_rotate_dls, rotate_flms
1014

1115
jax.config.update("jax_enable_x64", True)
@@ -120,3 +124,12 @@ def func(flm):
120124
return jnp.sum(jnp.abs(flm_rot - flm_target))
121125

122126
check_grads(func, (flm_start,), order=1, modes=("rev"))
127+
128+
129+
def test_try_import_module():
130+
# Use an intentionally long and unlikely to clash module name
131+
module_name = "_a_random_module_name_that_should_not_exist"
132+
with pytest.raises(
133+
MissingWrapperDependencyError, match="requires {module_name} to be installed"
134+
):
135+
_try_import_module(module_name)

0 commit comments

Comments
 (0)