From 6b883e0155f6a728943e839b917dae7941c4d867 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Mon, 27 Jan 2025 16:54:01 +0000 Subject: [PATCH 1/2] Make pyssht and healpy optional dependencies --- notebooks/JAX_HEALPix_frontend.ipynb | 4 +-- notebooks/JAX_SSHT_frontend.ipynb | 4 +-- pyproject.toml | 9 +++--- s2fft/transforms/c_backend_spherical.py | 41 ++++++++++++++++++++++--- 4 files changed, 46 insertions(+), 12 deletions(-) diff --git a/notebooks/JAX_HEALPix_frontend.ipynb b/notebooks/JAX_HEALPix_frontend.ipynb index 2330f3a2..dd699719 100644 --- a/notebooks/JAX_HEALPix_frontend.ipynb +++ b/notebooks/JAX_HEALPix_frontend.ipynb @@ -19,9 +19,9 @@ "import sys\n", "IN_COLAB = 'google.colab' in sys.modules\n", "\n", - "# Install s2fft and data if running on google colab.\n", + "# Install s2fft and healpy if running on google colab.\n", "if IN_COLAB:\n", - " !pip install s2fft &> /dev/null" + " !pip install s2fft healpy &> /dev/null" ] }, { diff --git a/notebooks/JAX_SSHT_frontend.ipynb b/notebooks/JAX_SSHT_frontend.ipynb index 3a3b9774..6697a591 100644 --- a/notebooks/JAX_SSHT_frontend.ipynb +++ b/notebooks/JAX_SSHT_frontend.ipynb @@ -19,9 +19,9 @@ "import sys\n", "IN_COLAB = 'google.colab' in sys.modules\n", "\n", - "# Install s2fft and data if running on google colab.\n", + "# Install s2fft and pyssht if running on google colab.\n", "if IN_COLAB:\n", - " !pip install s2fft &> /dev/null" + " !pip install s2fft pyssht &> /dev/null" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 6a2c310b..f66c9c1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,15 +27,12 @@ classifiers = [ ] description = "Differentiable and accelerated spherical transforms with JAX" dependencies = [ - "numpy>=1.20,<2", + "numpy>=1.20", "colorlog", "pyyaml", "jax>=0.3.13", "jaxlib", "torch", - "pyssht", - "healpy", - "ducc0", ] dynamic = [ "version", @@ -71,9 +68,13 @@ plotting = [ "ipywidgets", ] tests = [ + "ducc0", + "healpy", + "numpy<2", # Required currently due to lack of Numpy v2 compatible pyssht release "pytest", "pytest-cov", "so3", + "pyssht", ] [tool.scikit-build] diff --git a/s2fft/transforms/c_backend_spherical.py b/s2fft/transforms/c_backend_spherical.py index 06d1da9d..4ef7ae68 100644 --- a/s2fft/transforms/c_backend_spherical.py +++ b/s2fft/transforms/c_backend_spherical.py @@ -1,11 +1,11 @@ +"""C backend functions for which to provide JAX frontends.""" + +import importlib from functools import partial +from types import ModuleType -import healpy import jax.numpy as jnp import numpy as np - -# C backend functions for which to provide JAX frontend. -import pyssht from jax import core, custom_vjp from jax.interpreters import ad @@ -13,6 +13,21 @@ from s2fft.utils import iterative_refinement, quadrature_jax +class MissingWrapperDependencyError(Exception): + """Exception raised when a dependency for a wrapper function is missing.""" + + +def _try_import_module(module_name: str) -> ModuleType: + """Try to import a named module which may not be installed.""" + try: + module = importlib.import_module(module_name) + except ImportError as e: + raise MissingWrapperDependencyError( + "Wrapper function requires {module_name} to be installed" + ) from e + return module + + @custom_vjp def ssht_inverse( flm: jnp.ndarray, @@ -30,6 +45,8 @@ def ssht_inverse( custom JAX frontends, hence providing support for automatic differentiation. Currently these transforms can only be deployed on CPU, which is a limitation of the SSHT C package. + Requires `pyssht` package to be installed. + Args: flm (jnp.ndarray): Spherical harmonic coefficients. @@ -56,6 +73,7 @@ def ssht_inverse( IEEE Transactions on Signal Processing 59 (2011): 5876-5887. """ + pyssht = _try_import_module("pyssht") sampling_str = ["MW", "MWSS", "DH", "GL"] flm_1d = reindex.flm_2d_to_1d_fast(flm, L) _backend = "SSHT" if _ssht_backend == 0 else "ducc0" @@ -86,6 +104,7 @@ def _ssht_inverse_fwd( def _ssht_inverse_bwd(res, f): """Private function which implements the backward pass for inverse jax_ssht.""" + pyssht = _try_import_module("pyssht") _, L, spin, reality, ssht_sampling, _ssht_backend = res sampling_str = ["MW", "MWSS", "DH", "GL"] _backend = "SSHT" if _ssht_backend == 0 else "ducc0" @@ -149,6 +168,8 @@ def ssht_forward( custom JAX frontends, hence providing support for automatic differentiation. Currently these transforms can only be deployed on CPU, which is a limitation of the SSHT C package. + Requires `pyssht` package to be installed. + Args: f (jnp.ndarray): Signal on the sphere. @@ -175,6 +196,7 @@ def ssht_forward( IEEE Transactions on Signal Processing 59 (2011): 5876-5887. """ + pyssht = _try_import_module("pyssht") sampling_str = ["MW", "MWSS", "DH", "GL"] _backend = "SSHT" if _ssht_backend == 0 else "ducc0" flm = jnp.array( @@ -205,6 +227,7 @@ def _ssht_forward_fwd( def _ssht_forward_bwd(res, flm): """Private function which implements the backward pass for forward jax_ssht.""" + pyssht = _try_import_module("pyssht") _, L, spin, reality, ssht_sampling, _ssht_backend = res sampling_str = ["MW", "MWSS", "DH", "GL"] _backend = "SSHT" if _ssht_backend == 0 else "ducc0" @@ -300,6 +323,7 @@ def _real_dtype(complex_dtype): def _healpy_map2alm_impl(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: + healpy = _try_import_module("healpy") return jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=0)) @@ -332,6 +356,8 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional array of harmonic coefficients use :py:func:`healpy_forward`. + Requires `healpy` package to be installed. + Args: f (jnp.ndarray): Signal on the sphere. @@ -347,6 +373,7 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: def _healpy_alm2map_impl(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: + healpy = _try_import_module("healpy") return jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside)) @@ -384,6 +411,8 @@ def healpy_alm2map(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: dimensional array using HEALPix (ring-ordered) indexing. To instead pass a two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`. + Requires `healpy` package to be installed. + Args: flm (jnp.ndarray): Spherical harmonic coefficients. @@ -408,6 +437,8 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda Currently these transforms can only be deployed on CPU, which is a limitation of the C++ library. + Requires `healpy` package to be installed. + Args: f (jnp.ndarray): Signal on the sphere. @@ -448,6 +479,8 @@ def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: Currently these transforms can only be deployed on CPU, which is a limitation of the C++ library. + Requires `healpy` package to be installed. + Args: flm (jnp.ndarray): Spherical harmonic coefficients. From afdf08b623a50eaf124964d10c3f327c58e7c6eb Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Tue, 28 Jan 2025 09:42:05 +0000 Subject: [PATCH 2/2] Add test for try import function --- tests/test_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index ced9cda8..431d843d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,10 @@ from jax.test_util import check_grads from s2fft.sampling import s2_samples as samples +from s2fft.transforms.c_backend_spherical import ( + MissingWrapperDependencyError, + _try_import_module, +) from s2fft.utils.rotation import generate_rotate_dls, rotate_flms jax.config.update("jax_enable_x64", True) @@ -120,3 +124,12 @@ def func(flm): return jnp.sum(jnp.abs(flm_rot - flm_target)) check_grads(func, (flm_start,), order=1, modes=("rev")) + + +def test_try_import_module(): + # Use an intentionally long and unlikely to clash module name + module_name = "_a_random_module_name_that_should_not_exist" + with pytest.raises( + MissingWrapperDependencyError, match="requires {module_name} to be installed" + ): + _try_import_module(module_name)