Merged

52 commits
b039456
Make spherical precompute benchmarks compatible with pytorch
matt-graham Feb 24, 2025
5333b69
Add utilities for wrapping JAX function for use in PyTorch
matt-graham Feb 26, 2025
34459f0
Add initial wrapped versions of torch precompute transforms
matt-graham Feb 26, 2025
2801dfc
Update array conversion in benchmarks to avoid byte alignment warning
matt-graham Mar 6, 2025
8092bda
Remove previous torch precompute spherical transform implementations
matt-graham Mar 6, 2025
9c8cef7
Make precompute Wigner benchmarks compatible with torch
matt-graham Mar 6, 2025
1b4dc7e
Add torch wrapper precompute Wigner transforms
matt-graham Mar 6, 2025
d491426
Removing docstring for now removed using_torch arg
matt-graham Mar 6, 2025
421229c
Correct typo in docstring
matt-graham Mar 6, 2025
c020be2
Remove previous torch precompute Wigner transform implementations
matt-graham Mar 6, 2025
76b02b3
Update references to JAX in docstring when wrapping for torch use
matt-graham Mar 6, 2025
e7aa2fe
Try to infer differentiable args from annotations
matt-graham Mar 11, 2025
cc7d5c0
Update annotations of wrapped functions
matt-graham Mar 11, 2025
59ae417
Remove explicit differentiable argument name defs
matt-graham Mar 11, 2025
c2d55f1
Add helper function for wrapping all JAX functions in module
matt-graham Mar 11, 2025
1ebd797
Use wrappers for Torch resampling and quadrature modules
matt-graham Mar 11, 2025
b96c530
Remove using_torch option from signal generator functions
matt-graham Mar 11, 2025
5fde7ad
Update torch demo notebook to follow new wrapper interface
matt-graham Mar 11, 2025
6761781
Use wrappers for torch HEALPix FFT functions
matt-graham Mar 11, 2025
a0fe321
Make torch an optional dependency
matt-graham Mar 11, 2025
088dd8d
Use backwards compatible tree_map import
matt-graham Mar 11, 2025
bc440df
Copy annotations in wrapped function and check for existence of doc
matt-graham Mar 20, 2025
b24c174
Start on torch wrapper tests
matt-graham Mar 20, 2025
1ece73b
Add annotations futures import
matt-graham Mar 28, 2025
bd49bf7
Add additional torch wrapper tests
matt-graham Mar 28, 2025
8466258
Merge branch 'main' into mmg/pytorch-wrapper
matt-graham Apr 3, 2025
b1b7207
Make type alias Python 3.8 compatible
matt-graham Apr 4, 2025
935ad4b
Account for differing complex derivatives conventions between torch a…
matt-graham Apr 4, 2025
9532ca0
Add additional complex test cases and gradient checks to wrapper tests
matt-graham Apr 4, 2025
5c951df
Ignore complex warning due to casts in tests rather than erroring
matt-graham Apr 4, 2025
4d01933
Reduce max number iter tested for HEALPix to reduce test times
matt-graham Apr 4, 2025
a1b5d0b
Maintain compatibility with older JAX versions
matt-graham Apr 4, 2025
8d9a2ff
More maintaining compatibility with older JAX versions
matt-graham Apr 4, 2025
aff48ac
Correct typo in comment
matt-graham Apr 4, 2025
0d4698f
Explicitly cast kernels in einsum ops to avoid ComplexWarning causing…
matt-graham Apr 4, 2025
4990ca4
Force JAX double precision mode in Wigner precompute tests
matt-graham Apr 4, 2025
7208d9f
Add test for function checking torch available
matt-graham Apr 8, 2025
1c30841
Fix torch optional import logic to avoid errors when not installed
matt-graham Apr 11, 2025
c1ec18a
Refactor method dispatch logic in inverse transform
matt-graham Apr 16, 2025
9dd1656
Refactor method dispatch logic in HEALPix FFTs
matt-graham Apr 16, 2025
bf80707
Expose option to use HEALPix custom primitive in inverse transform
matt-graham Apr 16, 2025
addd2ce
Pass through method to select HEALPix (I)FFT function
matt-graham Apr 23, 2025
9efb251
Expose jax_cuda method in top-level spherical inverse function
matt-graham Apr 23, 2025
09b75d5
Refactor method dispatch logic in forward transform
matt-graham Apr 23, 2025
1715ead
Add OTF spherical transform torch wrappers
matt-graham Apr 23, 2025
a7c548f
Make torch wrapper diff arg inferring robust to non-type annotations
matt-graham Apr 23, 2025
364a6f6
Mark use_healpix_custom_primitive arg as static
matt-graham Apr 23, 2025
ff8513d
Include torch wrappers in tests
matt-graham Apr 23, 2025
cd3988d
Refactor method dispatch logic in Wigner transforms
matt-graham Apr 23, 2025
034ecbd
Add torch wrappers for Wigner OTF transforms
matt-graham Apr 23, 2025
6444bb0
Update README and notebook to indicate wider torch support
matt-graham Apr 23, 2025
a91729f
Pin JAX version to less than v0.6.0 due to breaking changes
matt-graham Apr 23, 2025
33 changes: 23 additions & 10 deletions benchmarks/precompute_spherical.py
@@ -1,5 +1,6 @@
"""Benchmarks for precompute spherical transforms."""

import jax
import numpy as np
from benchmarking import (
BenchmarkSetup,
@@ -10,6 +11,7 @@

import s2fft
import s2fft.precompute_transforms
from s2fft.utils import torch_wrapper

L_VALUES = [8, 16, 32, 64, 128, 256]
SPIN_VALUES = [0]
@@ -31,11 +33,17 @@ def setup_forward(method, L, sampling, spin, reality, recursion):
sampling=sampling,
reality=reality,
)
kernel_function = (
s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
if method == "jax"
else s2fft.precompute_transforms.construct.spin_spherical_kernel
)
# As the torch method wraps JAX functions, and converting a NumPy array to a torch
# tensor causes a 'DLPack buffer is not aligned' byte alignment warning when it is
# subsequently converted to a JAX array via the mutual DLPack support, we first
# convert the NumPy arrays to JAX arrays before converting to torch tensors, which
# eliminates this warning
if method.startswith("jax") or method.startswith("torch"):
flm = jax.numpy.asarray(flm)
f = jax.numpy.asarray(f)
if method.startswith("torch"):
flm, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flm, f))
kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
kernel = kernel_function(
L=L,
spin=spin,
@@ -73,11 +81,16 @@ def setup_inverse(method, L, sampling, spin, reality, recursion):
skip("Reality only valid for scalar fields (spin=0).")
rng = np.random.default_rng()
flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
kernel_function = (
s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
if method == "jax"
else s2fft.precompute_transforms.construct.spin_spherical_kernel
)
# As the torch method wraps JAX functions, and converting a NumPy array to a torch
# tensor causes a 'DLPack buffer is not aligned' byte alignment warning when it is
# subsequently converted to a JAX array via the mutual DLPack support, we first
# convert the NumPy array to a JAX array before converting to a torch tensor, which
# eliminates this warning
if method.startswith("jax") or method.startswith("torch"):
flm = jax.numpy.asarray(flm)
if method.startswith("torch"):
flm = torch_wrapper.jax_array_to_torch_tensor(flm)
kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
kernel = kernel_function(
L=L,
spin=spin,
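For reference, a minimal sketch of the conversion path the comments above describe, using the mutual DLPack support in JAX and torch directly; the array shape and dtype below are illustrative, and the torch_wrapper helpers referenced in the diff may implement this differently:

import jax
import jax.numpy as jnp
import numpy as np
from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack

jax.config.update("jax_enable_x64", True)  # keep 64-bit precision through the round trip

# Converting a NumPy array straight to a torch tensor can later trigger a
# 'DLPack buffer is not aligned' warning when the tensor is shared with JAX, so the
# array is first copied into a JAX-managed buffer and only then exposed to torch
# via DLPack.
flm_numpy = np.ones((8, 15))      # illustrative shape and dtype
flm_jax = jnp.asarray(flm_numpy)  # aligned buffer managed by JAX
flm_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(flm_jax))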
33 changes: 23 additions & 10 deletions benchmarks/precompute_wigner.py
@@ -1,5 +1,6 @@
"""Benchmarks for precompute Wigner-d transforms."""

import jax
import numpy as np
from benchmarking import (
BenchmarkSetup,
@@ -10,6 +11,7 @@
import s2fft
import s2fft.precompute_transforms
from s2fft.base_transforms import wigner as base_wigner
from s2fft.utils import torch_wrapper

L_VALUES = [16, 32, 64, 128, 256]
N_VALUES = [2]
@@ -31,11 +33,17 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode):
sampling=sampling,
reality=reality,
)
kernel_function = (
s2fft.precompute_transforms.construct.wigner_kernel_jax
if "jax" in method
else s2fft.precompute_transforms.construct.wigner_kernel
)
# As the torch method wraps JAX functions, and converting a NumPy array to a torch
# tensor causes a 'DLPack buffer is not aligned' byte alignment warning when it is
# subsequently converted to a JAX array via the mutual DLPack support, we first
# convert the NumPy arrays to JAX arrays before converting to torch tensors, which
# eliminates this warning
if method.startswith("jax") or method.startswith("torch"):
flmn = jax.numpy.asarray(flmn)
f = jax.numpy.asarray(f)
if method.startswith("torch"):
flmn, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flmn, f))
kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
kernel = kernel_function(
L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode
)
@@ -67,11 +75,16 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode):
def setup_inverse(method, L, N, L_lower, sampling, reality, mode):
rng = np.random.default_rng()
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
kernel_function = (
s2fft.precompute_transforms.construct.wigner_kernel_jax
if method == "jax"
else s2fft.precompute_transforms.construct.wigner_kernel
)
# As the torch method wraps JAX functions, and converting a NumPy array to a torch
# tensor causes a 'DLPack buffer is not aligned' byte alignment warning when it is
# subsequently converted to a JAX array via the mutual DLPack support, we first
# convert the NumPy array to a JAX array before converting to a torch tensor, which
# eliminates this warning
if method.startswith("jax") or method.startswith("torch"):
flmn = jax.numpy.asarray(flmn)
if method.startswith("torch"):
flmn = torch_wrapper.jax_array_to_torch_tensor(flmn)
kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
kernel = kernel_function(
L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode
)
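The `_kernel_functions[method]` lookup used in both benchmark modules replaces the previous if/else kernel selection. Purely as an illustration of the pattern, such a dispatch table might look like the sketch below; the actual mapping lives in `s2fft.precompute_transforms.wigner` (and `.spherical`) and may use different keys, and `_get_kernel_function` is a hypothetical helper:

from s2fft.precompute_transforms.construct import (
    wigner_kernel,
    wigner_kernel_jax,
    wigner_kernel_torch,
)

# Hypothetical dispatch table keyed by the benchmark / transform method string.
_kernel_functions = {
    "numpy": wigner_kernel,
    "jax": wigner_kernel_jax,
    "torch": wigner_kernel_torch,
}

def _get_kernel_function(method: str):
    # Fail with an informative error for unsupported method names.
    try:
        return _kernel_functions[method]
    except KeyError as exc:
        raise ValueError(f"Unknown method: {method}") from exc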
49 changes: 23 additions & 26 deletions notebooks/torch_frontend.ipynb
@@ -28,27 +28,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms. Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models."
"This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms. Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models. As the torch functions wrap the JAX implementations we need to configure JAX to use 64-bit precision floating point types by default to ensure sufficient precision for the transforms - `S2FFT` will emit a warning if this has not been done."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.\n"
]
}
],
"outputs": [],
"source": [
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"import torch \n",
"import numpy as np \n",
"from s2fft.precompute_transforms.spherical import inverse, forward\n",
"from s2fft.precompute_transforms.construct import spin_spherical_kernel\n",
"from s2fft.precompute_transforms.construct import spin_spherical_kernel_torch\n",
"from s2fft.utils import signal_generator"
]
},
@@ -65,9 +59,9 @@
"metadata": {},
"outputs": [],
"source": [
"L = 64 # Spherical harmonic bandlimit\n",
"rng = np.random.default_rng(1234951510) # Random seed for signal generator\n",
"flm = signal_generator.generate_flm(rng, L, using_torch=True) # Random set of spherical harmonic coefficients"
"L = 64 \n",
"rng = np.random.default_rng(1234951510)\n",
"flm = torch.from_numpy(signal_generator.generate_flm(rng, L))"
]
},
{
@@ -81,10 +75,18 @@
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"source": [
"inverse_kernel = spin_spherical_kernel(L, using_torch=True, forward=False) \n",
"forward_kernel = spin_spherical_kernel(L, using_torch=True, forward=True) "
"inverse_kernel = spin_spherical_kernel_torch(L, forward=False) \n",
"forward_kernel = spin_spherical_kernel_torch(L, forward=True) "
]
},
{
@@ -135,7 +137,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean absolute error = 1.1866908936078849e-14\n"
"Mean absolute error = 2.8472981477378884e-14\n"
]
}
],
Expand All @@ -146,7 +148,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 ('s2fft')",
"display_name": "s2fft",
"language": "python",
"name": "python3"
},
@@ -160,14 +162,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
"version": "3.11.10"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
}
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
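Putting the updated notebook cells together, the round trip reads roughly as below; the keyword arguments passed to `inverse` and `forward` are illustrative and assume the interface used in earlier versions of the notebook:

import jax

jax.config.update("jax_enable_x64", True)  # required for sufficient precision

import numpy as np
import torch
from s2fft.precompute_transforms.construct import spin_spherical_kernel_torch
from s2fft.precompute_transforms.spherical import forward, inverse
from s2fft.utils import signal_generator

L = 64                                   # spherical harmonic bandlimit
rng = np.random.default_rng(1234951510)  # random seed for signal generator
flm = torch.from_numpy(signal_generator.generate_flm(rng, L))

# Precompute kernels for the inverse and forward transforms.
inverse_kernel = spin_spherical_kernel_torch(L, forward=False)
forward_kernel = spin_spherical_kernel_torch(L, forward=True)

# Round trip: synthesise a signal on the sphere, then recover the coefficients.
f = inverse(flm, L, kernel=inverse_kernel, method="torch")
flm_check = forward(f, L, kernel=forward_kernel, method="torch")
print("Mean absolute error =", torch.mean(torch.abs(flm - flm_check)).item())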
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -30,7 +30,6 @@ dependencies = [
"numpy>=1.20",
"jax>=0.3.13",
"jaxlib",
"torch",
]
dynamic = [
"version",
@@ -74,6 +73,10 @@ tests = [
"pytest-cov",
"so3",
"pyssht",
"torch",
]
torch = [
"torch",
]

[tool.scikit-build]
26 changes: 16 additions & 10 deletions s2fft/precompute_transforms/construct.py
@@ -4,11 +4,10 @@
import jax
import jax.numpy as jnp
import numpy as np
import torch

from s2fft import recursions
from s2fft.sampling import s2_samples as samples
from s2fft.utils import quadrature, quadrature_jax
from s2fft.utils import quadrature, quadrature_jax, torch_wrapper

# Maximum spin number at which Price-McEwen recursion is sufficiently accurate.
# For spins > PM_MAX_STABLE_SPIN one should default to the Risbo recursion.
@@ -22,7 +21,6 @@ def spin_spherical_kernel(
sampling: str = "mw",
nside: int = None,
forward: bool = True,
using_torch: bool = False,
recursion: str = "auto",
) -> np.ndarray:
r"""
@@ -50,8 +48,6 @@
forward (bool, optional): Whether to provide forward or inverse shift.
Defaults to False.

using_torch (bool, optional): Desired frontend functionality. Defaults to False.

recursion (str, optional): Recursion to adopt. Supported recursion schemes include
{"auto", "price-mcewen", "risbo"}. Defaults to "auto" which will detect the
most appropriate recursion given the parameter configuration.
@@ -163,7 +159,7 @@
healpix_phase_shifts(L, nside, forward)[:, m_start_ind:],
)

return torch.from_numpy(dl) if using_torch else dl
return dl


def spin_spherical_kernel_jax(
@@ -329,6 +325,11 @@ def spin_spherical_kernel_jax(
return dl


spin_spherical_kernel_torch = torch_wrapper.wrap_as_torch_function(
spin_spherical_kernel_jax
)


def wigner_kernel(
L: int,
N: int,
@@ -337,7 +338,6 @@
nside: int = None,
forward: bool = False,
mode: str = "auto",
using_torch: bool = False,
) -> np.ndarray:
r"""
Precompute the wigner-d kernel for Wigner transform.
@@ -368,8 +368,6 @@
{"auto", "direct", "fft"}. Defaults to "auto" which will detect the
most appropriate recursion given the parameter configuration.

using_torch (bool, optional): Desired frontend functionality. Defaults to False.

Returns:
np.ndarray: Transform kernel for Wigner transform.

@@ -468,7 +466,7 @@
healpix_phase_shifts(L, nside, forward),
)

return torch.from_numpy(dl) if using_torch else dl
return dl


def wigner_kernel_jax(
@@ -611,6 +609,9 @@
return dl


wigner_kernel_torch = torch_wrapper.wrap_as_torch_function(wigner_kernel_jax)


def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
@@ -667,6 +668,11 @@ def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
return deltas, w


fourier_wigner_kernel_torch = torch_wrapper.wrap_as_torch_function(
fourier_wigner_kernel_jax
)


def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray:
r"""
Generates a phase shift vector for HEALPix for all :math:`\theta` rings.
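The new `torch_wrapper.wrap_as_torch_function` helper is what turns the JAX kernel constructors and transforms above into torch-callable, torch-differentiable functions. Its implementation is not part of this diff; the sketch below shows the general technique it is assumed to use (a `torch.autograd.Function` bridging tensors and arrays via DLPack, with `jax.vjp` supplying the backward pass), restricted to a single real-array argument for brevity. The real wrapper additionally infers differentiable arguments from annotations, copies docstrings and annotations, and accounts for the differing complex-derivative conventions of torch and JAX, none of which is handled here:

import jax
import torch
from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack

def _to_jax(tensor):
    # torch tensor -> JAX array via DLPack (zero copy where possible).
    return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(tensor.detach().contiguous()))

def _to_torch(array):
    # JAX array -> torch tensor via DLPack.
    return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(array))

def wrap_as_torch_function_sketch(jax_func):
    """Wrap a JAX function of one real array so it is differentiable in torch."""

    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Evaluate the JAX function and record its vector-Jacobian product.
            y, vjp = jax.vjp(jax_func, _to_jax(x))
            ctx.vjp = vjp
            return _to_torch(y)

        @staticmethod
        def backward(ctx, grad_y):
            # Pull the torch gradient back through the JAX computation.
            (grad_x,) = ctx.vjp(_to_jax(grad_y))
            return _to_torch(grad_x)

    return _Wrapped.apply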