Merged

52 commits
b039456
Make spherical precompute benchmarks compatible with pytorch
matt-graham Feb 24, 2025
5333b69
Add utilities for wrapping JAX function for use in PyTorch
matt-graham Feb 26, 2025
34459f0
Add initial wrapped versions of torch precompute transforms
matt-graham Feb 26, 2025
2801dfc
Update array conversion in benchmarks to avoid byte alignment warning
matt-graham Mar 6, 2025
8092bda
Remove previous torch precompute spherical transform implementations
matt-graham Mar 6, 2025
9c8cef7
Make precompute Wigner benchmarks compatible with torch
matt-graham Mar 6, 2025
1b4dc7e
Add torch wrapper precompute Wigner transforms
matt-graham Mar 6, 2025
d491426
Removing docstring for now removed using_torch arg
matt-graham Mar 6, 2025
421229c
Correct typo in docstring
matt-graham Mar 6, 2025
c020be2
Remove previous torch precompute Wigner transform implementations
matt-graham Mar 6, 2025
76b02b3
Update references to JAX in docstring when wrapping for torch use
matt-graham Mar 6, 2025
e7aa2fe
Try to infer differentiable args from annotations
matt-graham Mar 11, 2025
cc7d5c0
Update annotations of wrapped functions
matt-graham Mar 11, 2025
59ae417
Remove explicit differentiable argument name defs
matt-graham Mar 11, 2025
c2d55f1
Add helper function for wrapping all JAX functions in module
matt-graham Mar 11, 2025
1ebd797
Use wrappers for Torch resampling and quadrature modules
matt-graham Mar 11, 2025
b96c530
Remove using_torch option from signal generator functions
matt-graham Mar 11, 2025
5fde7ad
Update torch demo notebook to follow new wrapper interface
matt-graham Mar 11, 2025
6761781
Use wrappers for torch HEALPix FFT functions
matt-graham Mar 11, 2025
a0fe321
Make torch an optional dependency
matt-graham Mar 11, 2025
088dd8d
Use backwards compatible tree_map import
matt-graham Mar 11, 2025
bc440df
Copy annotations in wrapped function and check for existence of doc
matt-graham Mar 20, 2025
b24c174
Start on torch wrapper tests
matt-graham Mar 20, 2025
1ece73b
Add annotations futures import
matt-graham Mar 28, 2025
bd49bf7
Add additional torch wrapper tests
matt-graham Mar 28, 2025
8466258
Merge branch 'main' into mmg/pytorch-wrapper
matt-graham Apr 3, 2025
b1b7207
Make type alias Python 3.8 compatible
matt-graham Apr 4, 2025
935ad4b
Account for differing complex derivatives conventions between torch a…
matt-graham Apr 4, 2025
9532ca0
Add additional complex test cases and gradient checks to wrapper tests
matt-graham Apr 4, 2025
5c951df
Ignore complex warning due to casts in tests rather than erroring
matt-graham Apr 4, 2025
4d01933
Reduce max number iter tested for HEALPix to reduce test times
matt-graham Apr 4, 2025
a1b5d0b
Maintain compatibility with older JAX versions
matt-graham Apr 4, 2025
8d9a2ff
More maintaining compatibility with older JAX versions
matt-graham Apr 4, 2025
aff48ac
Correct typo in comment
matt-graham Apr 4, 2025
0d4698f
Explicitly cast kernels in einsum ops to avoid ComplexWarning causing…
matt-graham Apr 4, 2025
4990ca4
Force JAX double precision mode in Wigner precompute tests
matt-graham Apr 4, 2025
7208d9f
Add test for function checking torch available
matt-graham Apr 8, 2025
1c30841
Fix torch optional import logic to avoid errors when not installed
matt-graham Apr 11, 2025
c1ec18a
Refactor method dispatch logic in inverse transform
matt-graham Apr 16, 2025
9dd1656
Refactor method dispatch logic in HEALPix FFTs
matt-graham Apr 16, 2025
bf80707
Expose option to use HEALPix custom primitive in inverse transform
matt-graham Apr 16, 2025
addd2ce
Pass through method to select HEALPix (I)FFT function
matt-graham Apr 23, 2025
9efb251
Expose jax_cuda method in top-level spherical inverse function
matt-graham Apr 23, 2025
09b75d5
Refactor method dispatch logic in forward transform
matt-graham Apr 23, 2025
1715ead
Add OTF spherical transform torch wrappers
matt-graham Apr 23, 2025
a7c548f
Make torch wrapper diff arg inferring robust to non-type annotations
matt-graham Apr 23, 2025
364a6f6
Mark use_healpix_custom_primitive arg as static
matt-graham Apr 23, 2025
ff8513d
Include torch wrappers in tests
matt-graham Apr 23, 2025
cd3988d
Refactor method dispatch logic in Wigner transforms
matt-graham Apr 23, 2025
034ecbd
Add torch wrappers for Wigner OTF transforms
matt-graham Apr 23, 2025
6444bb0
Update README and notebook to indicate wider torch support
matt-graham Apr 23, 2025
a91729f
Pin JAX version to less than v0.6.0 due to breaking changes
matt-graham Apr 23, 2025
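The heart of the change set is a utility module for wrapping JAX functions so they can be called (and differentiated) from PyTorch; the wrapper code itself does not appear in the file diffs below. As a rough, hypothetical sketch only (not the PR's actual implementation), such a wrapper typically pairs `torch.autograd.Function` with `jax.vjp` and DLPack conversions; the sketch glosses over details the commits mention, such as differing complex-derivative conventions and inferring which arguments are differentiable:

```python
import jax
import jax.dlpack
import torch
import torch.utils.dlpack


def torch_to_jax(t):
    # Hand a torch tensor to JAX via DLPack (zero-copy on matching devices).
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t.detach().contiguous()))


def jax_to_torch(a):
    # Hand a JAX array back to torch via DLPack.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(a))


def wrap_jax_function(jax_fn):
    # Expose a single-argument JAX function as a differentiable torch operation.
    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            y, vjp_fn = jax.vjp(jax_fn, torch_to_jax(x))
            ctx.vjp_fn = vjp_fn
            return jax_to_torch(y)

        @staticmethod
        def backward(ctx, grad_y):
            (grad_x,) = ctx.vjp_fn(torch_to_jax(grad_y))
            return jax_to_torch(grad_x)

    return _Wrapped.apply


# Example: differentiate a JAX sum-of-squares through torch autograd.
square_sum = wrap_jax_function(lambda x: jax.numpy.sum(x**2))
t = torch.arange(3, dtype=torch.float32, requires_grad=True)
square_sum(t).backward()
print(t.grad)  # tensor([0., 2., 4.])
```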
2 changes: 1 addition & 1 deletion README.md
@@ -195,7 +195,7 @@ f = fft.wigner.inverse_jax(flmn, L, N, method="jax")
For further details on usage see the [documentation](https://astro-informatics.github.io/s2fft/) and associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/spherical_harmonic_transform.html).

> [!NOTE]
> We also provide PyTorch support for the precompute version of our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html).
> We also provide PyTorch support for our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html). This wraps the JAX implementations, so JAX will need to be installed in addition to PyTorch.

## SSHT & HEALPix wrappers 💡

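As a quick illustration of the note above (a sketch drawn from the Torch frontend tutorial notebook later in this diff, not text that is part of the README change), the wrapped transforms are selected with `method="torch"`:

```python
import jax
import numpy as np
import torch

# The torch frontend calls into JAX, which needs 64-bit mode for full precision.
jax.config.update("jax_enable_x64", True)

from s2fft.transforms.spherical import forward, inverse
from s2fft.utils import signal_generator

L = 64
rng = np.random.default_rng(1234951510)
flm = torch.from_numpy(signal_generator.generate_flm(rng, L))  # random harmonic coefficients

f = inverse(flm, L, method="torch")         # signal on the sphere
flm_check = forward(f, L, method="torch")   # round trip back to coefficients
print(np.nanmean(np.abs(flm_check - flm)))  # ~1e-14 round-trip error
```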
33 changes: 23 additions & 10 deletions benchmarks/precompute_spherical.py
@@ -1,5 +1,6 @@
"""Benchmarks for precompute spherical transforms."""

import jax
import numpy as np
from benchmarking import (
BenchmarkSetup,
@@ -10,6 +11,7 @@

import s2fft
import s2fft.precompute_transforms
from s2fft.utils import torch_wrapper

L_VALUES = [8, 16, 32, 64, 128, 256]
SPIN_VALUES = [0]
@@ -31,11 +33,17 @@ def setup_forward(method, L, sampling, spin, reality, recursion):
sampling=sampling,
reality=reality,
)
kernel_function = (
s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
if method == "jax"
else s2fft.precompute_transforms.construct.spin_spherical_kernel
)
# As the torch method wraps JAX functions, and converting a NumPy array to a torch
# tensor causes a 'DLPack buffer is not aligned' byte-alignment warning when it is
# subsequently converted to a JAX array via the mutual DLPack support, we first
# convert the NumPy arrays to JAX arrays before converting to torch tensors, which
# eliminates this warning
if method.startswith("jax") or method.startswith("torch"):
flm = jax.numpy.asarray(flm)
f = jax.numpy.asarray(f)
if method.startswith("torch"):
flm, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flm, f))
kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
kernel = kernel_function(
L=L,
spin=spin,
@@ -73,11 +81,16 @@ def setup_inverse(method, L, sampling, spin, reality, recursion):
skip("Reality only valid for scalar fields (spin=0).")
rng = np.random.default_rng()
flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
kernel_function = (
s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
if method == "jax"
else s2fft.precompute_transforms.construct.spin_spherical_kernel
)
# As the torch method wraps JAX functions, and converting a NumPy array to a torch
# tensor causes a 'DLPack buffer is not aligned' byte-alignment warning when it is
# subsequently converted to a JAX array via the mutual DLPack support, we first
# convert the NumPy array to a JAX array before converting to a torch tensor, which
# eliminates this warning
if method.startswith("jax") or method.startswith("torch"):
flm = jax.numpy.asarray(flm)
if method.startswith("torch"):
flm = torch_wrapper.jax_array_to_torch_tensor(flm)
kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
kernel = kernel_function(
L=L,
spin=spin,
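The comments added to both setup functions above describe the conversion order that avoids the DLPack alignment warning; as a standalone sketch of that pattern (assuming the `torch_wrapper` helpers behave as they are used in these benchmarks):

```python
import jax
import numpy as np

from s2fft.utils import torch_wrapper

rng = np.random.default_rng(0)
flm_np = rng.standard_normal((8, 8))
f_np = rng.standard_normal((8, 15))

# Going NumPy -> torch -> JAX via DLPack can emit a 'DLPack buffer is not aligned'
# warning, so the benchmarks go NumPy -> JAX -> torch instead.
flm_jax, f_jax = jax.numpy.asarray(flm_np), jax.numpy.asarray(f_np)

# Single array, as in setup_inverse:
flm_torch = torch_wrapper.jax_array_to_torch_tensor(flm_jax)

# Several arrays at once, as in setup_forward:
flm_torch, f_torch = torch_wrapper.tree_map_jax_array_to_torch_tensor((flm_jax, f_jax))
```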
33 changes: 23 additions & 10 deletions benchmarks/precompute_wigner.py
@@ -1,5 +1,6 @@
"""Benchmarks for precompute Wigner-d transforms."""

import jax
import numpy as np
from benchmarking import (
BenchmarkSetup,
@@ -10,6 +11,7 @@
import s2fft
import s2fft.precompute_transforms
from s2fft.base_transforms import wigner as base_wigner
from s2fft.utils import torch_wrapper

L_VALUES = [16, 32, 64, 128, 256]
N_VALUES = [2]
@@ -31,11 +33,17 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode):
sampling=sampling,
reality=reality,
)
kernel_function = (
s2fft.precompute_transforms.construct.wigner_kernel_jax
if "jax" in method
else s2fft.precompute_transforms.construct.wigner_kernel
)
# As the torch method wraps JAX functions, and converting a NumPy array to a torch
# tensor causes a 'DLPack buffer is not aligned' byte-alignment warning when it is
# subsequently converted to a JAX array via the mutual DLPack support, we first
# convert the NumPy arrays to JAX arrays before converting to torch tensors, which
# eliminates this warning
if method.startswith("jax") or method.startswith("torch"):
flmn = jax.numpy.asarray(flmn)
f = jax.numpy.asarray(f)
if method.startswith("torch"):
flmn, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flmn, f))
kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
kernel = kernel_function(
L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode
)
@@ -67,11 +75,16 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode):
def setup_inverse(method, L, N, L_lower, sampling, reality, mode):
rng = np.random.default_rng()
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
kernel_function = (
s2fft.precompute_transforms.construct.wigner_kernel_jax
if method == "jax"
else s2fft.precompute_transforms.construct.wigner_kernel
)
# As the torch method wraps JAX functions, and converting a NumPy array to a torch
# tensor causes a 'DLPack buffer is not aligned' byte-alignment warning when it is
# subsequently converted to a JAX array via the mutual DLPack support, we first
# convert the NumPy array to a JAX array before converting to a torch tensor, which
# eliminates this warning
if method.startswith("jax") or method.startswith("torch"):
flmn = jax.numpy.asarray(flmn)
if method.startswith("torch"):
flmn = torch_wrapper.jax_array_to_torch_tensor(flmn)
kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
kernel = kernel_function(
L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode
)
125 changes: 89 additions & 36 deletions notebooks/torch_frontend.ipynb
@@ -28,27 +28,24 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms. Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models."
"This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms. Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models. As the torch functions wrap the JAX implementations we need to configure JAX to use 64-bit precision floating point types by default to ensure sufficient precision for the transforms - `S2FFT` will emit a warning if this has not been done."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.\n"
]
}
],
"outputs": [],
"source": [
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"import torch \n",
"import numpy as np \n",
"from s2fft.precompute_transforms.spherical import inverse, forward\n",
"from s2fft.precompute_transforms.construct import spin_spherical_kernel\n",
"import numpy as np\n",
"from s2fft.transforms.spherical import inverse, forward\n",
"from s2fft.precompute_transforms.spherical import (\n",
" inverse as precompute_inverse, forward as precompute_forward\n",
")\n",
"from s2fft.precompute_transforms.construct import spin_spherical_kernel_torch\n",
"from s2fft.utils import signal_generator"
]
},
@@ -65,33 +62,40 @@
"metadata": {},
"outputs": [],
"source": [
"L = 64 # Spherical harmonic bandlimit\n",
"rng = np.random.default_rng(1234951510) # Random seed for signal generator\n",
"flm = signal_generator.generate_flm(rng, L, using_torch=True) # Random set of spherical harmonic coefficients"
"L = 64 \n",
"rng = np.random.default_rng(1234951510)\n",
"flm = torch.from_numpy(signal_generator.generate_flm(rng, L))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the fully precompute transform we must also generate the precompute kernels which we store as a torch tensors."
"Now lets calculate the signal on the sphere by applying the inverse spherical harmonic transform"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"source": [
"inverse_kernel = spin_spherical_kernel(L, using_torch=True, forward=False) \n",
"forward_kernel = spin_spherical_kernel(L, using_torch=True, forward=True) "
"f = inverse(flm, L, method=\"torch\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now lets calculate the signal on the sphere by applying the inverse spherical harmonic transform"
"To calculate the corresponding spherical harmonic representation execute"
]
},
{
@@ -100,53 +104,107 @@
"metadata": {},
"outputs": [],
"source": [
"f = inverse(flm, L, 0, inverse_kernel, method=\"torch\")"
"flm_check = forward(f, L, method=\"torch\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To calculate the corresponding spherical harmonic representation execute"
"Finally, lets check the error on the round trip is as expected for 64 bit machine precision floating point arithmetic"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean absolute error = 2.8915048238993476e-14\n"
]
}
],
"source": [
"flm_check = forward(f, L, 0, forward_kernel, method=\"torch\")"
"print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, lets check the error on the roundtrip is at 64bit machine precision"
"For the fully precompute transform we must also generate the precompute kernels which we store as a torch tensors."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"inverse_kernel = spin_spherical_kernel_torch(L, forward=False) \n",
"forward_kernel = spin_spherical_kernel_torch(L, forward=True) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We then pass the kernels as additional arguments to the transform functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'orward_kernel' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m precompute_f \u001b[38;5;241m=\u001b[39m precompute_inverse(flm, L, kernel\u001b[38;5;241m=\u001b[39minverse_kernel, method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m precompute_flm_check \u001b[38;5;241m=\u001b[39m precompute_forward(f, L, kernel\u001b[38;5;241m=\u001b[39m\u001b[43morward_kernel\u001b[49m, method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mNameError\u001b[0m: name 'orward_kernel' is not defined"
]
}
],
"source": [
"precompute_f = precompute_inverse(flm, L, kernel=inverse_kernel, method=\"torch\")\n",
"precompute_flm_check = precompute_forward(f, L, kernel=forward_kernel, method=\"torch\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, we check the error on the round trip is as expected"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean absolute error = 1.1866908936078849e-14\n"
"Mean absolute error = 2.8472981477378884e-14\n"
]
}
],
"source": [
"print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
"print(f\"Mean absolute error = {np.nanmean(np.abs(precompute_flm_check - flm))}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 ('s2fft')",
"display_name": "s2fft",
"language": "python",
"name": "python3"
},
Expand All @@ -160,14 +218,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
"version": "3.11.10"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
}
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -28,9 +28,8 @@ classifiers = [
description = "Differentiable and accelerated spherical transforms with JAX"
dependencies = [
"numpy>=1.20",
"jax>=0.3.13",
"jax>=0.3.13,<0.6.0",
"jaxlib",
"torch",
]
dynamic = [
"version",
@@ -74,6 +73,10 @@ tests = [
"pytest-cov",
"so3",
"pyssht",
"torch",
]
torch = [
"torch",
]

[tool.scikit-build]
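With `torch` moved from the core dependencies to an optional `torch` extra, code paths that use it have to tolerate it being absent (see the "Fix torch optional import logic" and "Add test for function checking torch available" commits). A generic, hypothetical sketch of that pattern, not the PR's actual helper:

```python
# Hypothetical guarded import; the real helper in s2fft may differ in name and behaviour.
try:
    import torch  # noqa: F401

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False


def check_torch_available() -> None:
    # Raise a clear error when a torch code path is requested without the extra installed.
    if not TORCH_AVAILABLE:
        raise RuntimeError(
            "PyTorch is required for method='torch'; install the optional extra, "
            "e.g. pip install 's2fft[torch]'."
        )
```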