Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions notebooks/JAX_HEALPix_frontend.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
"import sys\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"# Install s2fft and data if running on google colab.\n",
"# Install s2fft and healpy if running on google colab.\n",
"if IN_COLAB:\n",
" !pip install s2fft &> /dev/null"
" !pip install s2fft healpy &> /dev/null"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions notebooks/JAX_SSHT_frontend.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
"import sys\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"# Install s2fft and data if running on google colab.\n",
"# Install s2fft and pyssht if running on google colab.\n",
"if IN_COLAB:\n",
" !pip install s2fft &> /dev/null"
" !pip install s2fft pyssht &> /dev/null"
]
},
{
Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ classifiers = [
]
description = "Differentiable and accelerated spherical transforms with JAX"
dependencies = [
"numpy>=1.20,<2",
"numpy>=1.20",
"colorlog",
"pyyaml",
"jax>=0.3.13",
"jaxlib",
"torch",
"pyssht",
"healpy",
"ducc0",
]
dynamic = [
"version",
Expand Down Expand Up @@ -71,9 +68,13 @@ plotting = [
"ipywidgets",
]
tests = [
"ducc0",
"healpy",
"numpy<2", # Required currently due to lack of Numpy v2 compatible pyssht release
"pytest",
"pytest-cov",
"so3",
"pyssht",
]

[tool.scikit-build]
Expand Down
41 changes: 37 additions & 4 deletions s2fft/transforms/c_backend_spherical.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,33 @@
"""C backend functions for which to provide JAX frontends."""

import importlib
from functools import partial
from types import ModuleType

import healpy
import jax.numpy as jnp
import numpy as np

# C backend functions for which to provide JAX frontend.
import pyssht
from jax import core, custom_vjp
from jax.interpreters import ad

from s2fft.sampling import reindex
from s2fft.utils import iterative_refinement, quadrature_jax


class MissingWrapperDependencyError(Exception):
"""Exception raised when a dependency for a wrapper function is missing."""


def _try_import_module(module_name: str) -> ModuleType:
"""Try to import a named module which may not be installed."""
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise MissingWrapperDependencyError(
"Wrapper function requires {module_name} to be installed"
) from e
return module


@custom_vjp
def ssht_inverse(
flm: jnp.ndarray,
Expand All @@ -30,6 +45,8 @@ def ssht_inverse(
custom JAX frontends, hence providing support for automatic differentiation. Currently
these transforms can only be deployed on CPU, which is a limitation of the SSHT C package.

Requires `pyssht` package to be installed.

Args:
flm (jnp.ndarray): Spherical harmonic coefficients.

Expand All @@ -56,6 +73,7 @@ def ssht_inverse(
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.

"""
pyssht = _try_import_module("pyssht")
sampling_str = ["MW", "MWSS", "DH", "GL"]
flm_1d = reindex.flm_2d_to_1d_fast(flm, L)
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
Expand Down Expand Up @@ -86,6 +104,7 @@ def _ssht_inverse_fwd(

def _ssht_inverse_bwd(res, f):
"""Private function which implements the backward pass for inverse jax_ssht."""
pyssht = _try_import_module("pyssht")
_, L, spin, reality, ssht_sampling, _ssht_backend = res
sampling_str = ["MW", "MWSS", "DH", "GL"]
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
Expand Down Expand Up @@ -149,6 +168,8 @@ def ssht_forward(
custom JAX frontends, hence providing support for automatic differentiation. Currently
these transforms can only be deployed on CPU, which is a limitation of the SSHT C package.

Requires `pyssht` package to be installed.

Args:
f (jnp.ndarray): Signal on the sphere.

Expand All @@ -175,6 +196,7 @@ def ssht_forward(
IEEE Transactions on Signal Processing 59 (2011): 5876-5887.

"""
pyssht = _try_import_module("pyssht")
sampling_str = ["MW", "MWSS", "DH", "GL"]
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
flm = jnp.array(
Expand Down Expand Up @@ -205,6 +227,7 @@ def _ssht_forward_fwd(

def _ssht_forward_bwd(res, flm):
"""Private function which implements the backward pass for forward jax_ssht."""
pyssht = _try_import_module("pyssht")
_, L, spin, reality, ssht_sampling, _ssht_backend = res
sampling_str = ["MW", "MWSS", "DH", "GL"]
_backend = "SSHT" if _ssht_backend == 0 else "ducc0"
Expand Down Expand Up @@ -300,6 +323,7 @@ def _real_dtype(complex_dtype):


def _healpy_map2alm_impl(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
healpy = _try_import_module("healpy")
return jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=0))


Expand Down Expand Up @@ -332,6 +356,8 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional
array of harmonic coefficients use :py:func:`healpy_forward`.

Requires `healpy` package to be installed.

Args:
f (jnp.ndarray): Signal on the sphere.

Expand All @@ -347,6 +373,7 @@ def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:


def _healpy_alm2map_impl(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
healpy = _try_import_module("healpy")
return jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))


Expand Down Expand Up @@ -384,6 +411,8 @@ def healpy_alm2map(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
dimensional array using HEALPix (ring-ordered) indexing. To instead pass a
two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`.

Requires `healpy` package to be installed.

Args:
flm (jnp.ndarray): Spherical harmonic coefficients.

Expand All @@ -408,6 +437,8 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
Currently these transforms can only be deployed on CPU, which is a limitation of the
C++ library.

Requires `healpy` package to be installed.

Args:
f (jnp.ndarray): Signal on the sphere.

Expand Down Expand Up @@ -448,6 +479,8 @@ def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
Currently these transforms can only be deployed on CPU, which is a limitation of the
C++ library.

Requires `healpy` package to be installed.

Args:
flm (jnp.ndarray): Spherical harmonic coefficients.

Expand Down
13 changes: 13 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from jax.test_util import check_grads

from s2fft.sampling import s2_samples as samples
from s2fft.transforms.c_backend_spherical import (
MissingWrapperDependencyError,
_try_import_module,
)
from s2fft.utils.rotation import generate_rotate_dls, rotate_flms

jax.config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -120,3 +124,12 @@ def func(flm):
return jnp.sum(jnp.abs(flm_rot - flm_target))

check_grads(func, (flm_start,), order=1, modes=("rev"))


def test_try_import_module():
# Use an intentionally long and unlikely to clash module name
module_name = "_a_random_module_name_that_should_not_exist"
with pytest.raises(
MissingWrapperDependencyError, match="requires {module_name} to be installed"
):
_try_import_module(module_name)
Loading