Skip to content

Commit e7c68f2

Browse files
committed
pull gl sampling into torch frontend
2 parents 95ea4f5 + 8e111f4 commit e7c68f2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1815
-151
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
pip install jaxlib
3030
pip install -r requirements/requirements-core.txt
3131
pip install -r requirements/requirements-docs.txt
32-
pip install .
32+
pip install .\[torch\]
3333
3434
- name: Build Documentation
3535
run: |

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
python -m pip install --upgrade pip
3131
pip install -r requirements/requirements-tests.txt
3232
pip install -r requirements/requirements-core.txt
33-
pip install .
33+
pip install .\[torch\]
3434
3535
- name: Run tests
3636
run: |

.pip_readme.rst

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
.. image:: https://colab.research.google.com/assets/colab-badge.svg
1414
:target: https://colab.research.google.com/drive/1YmJ2ljsF8HBvhPmD4hrYPlyAKc4WPUgq?usp=sharing
1515

16-
Differentiable and accelerated spherical transforms with JAX
16+
Differentiable and accelerated spherical transforms
1717
=================================================================================================================
1818

19-
`S2FFT` is a JAX package for computing Fourier transforms on the sphere
20-
and rotation group. It leverages autodiff to provide differentiable
19+
`S2FFT` is a Python package for computing Fourier transforms on the sphere
20+
and rotation group using JAX and PyTorch. It leverages autodiff to provide differentiable
2121
transforms, which are also deployable on hardware accelerators
2222
(e.g. GPUs and TPUs).
2323

@@ -27,6 +27,10 @@ for adjoint transformations where needed, and comes with different
2727
optimisations (precompute or not) that one may select depending on
2828
available resources and desired angular resolution $L$.
2929

30+
As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
31+
precompute transforms. In future releases this support will be extended to our
32+
on-the-fly algorithms.
33+
3034
Documentation
3135
=============
3236
Read the full documentation `here <https://astro-informatics.github.io/s2fft/>`_.

README.md

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,23 @@
99

1010
<img align="left" height="85" width="98" src="./docs/assets/sax_logo.png">
1111

12-
# Differentiable and accelerated spherical transforms with JAX
12+
# Differentiable and accelerated spherical transforms
1313

14-
`S2FFT` is a JAX package for computing Fourier transforms on the sphere
15-
and rotation group [(Price & McEwen 2023)](https://arxiv.org/abs/2311.14670). It leverages autodiff to provide differentiable
16-
transforms, which are also deployable on hardware accelerators
17-
(e.g. GPUs and TPUs).
14+
`S2FFT` is a Python package for computing Fourier transforms on the sphere
15+
and rotation group [(Price & McEwen 2023)](https://arxiv.org/abs/2311.14670) using
16+
JAX or PyTorch. It leverages autodiff to provide differentiable transforms, which are
17+
also deployable on hardware accelerators (e.g. GPUs and TPUs).
1818

1919
More specifically, `S2FFT` provides support for spin spherical harmonic
2020
and Wigner transforms (for both real and complex signals), with support
2121
for adjoint transformations where needed, and comes with different
2222
optimisations (precompute or not) that one may select depending on
2323
available resources and desired angular resolution $L$.
2424

25+
As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
26+
precompute transforms. In future releases this support will be extended to our
27+
on-the-fly algorithms.
28+
2529
## Algorithms :zap:
2630

2731
`S2FFT` leverages new algorithmic structures that can he highly
@@ -79,21 +83,33 @@ The Python dependencies for the `S2FFT` package are listed in the file
7983
`requirements/requirements-core.txt` and will be automatically installed
8084
into the active python environment by [pip](https://pypi.org) when running
8185

86+
``` bash
87+
pip install s2fft
88+
```
89+
This will install all core functionality which includes JAX support. To install `S2FFT`
90+
with PyTorch support run
91+
92+
``` bash
93+
pip install s2fft[torch]
94+
```
95+
96+
Alternatively, the `S2FFT` package may be installed directly from GitHub by cloning this
97+
repository and then running
98+
8299
``` bash
83100
pip install .
84101
```
85102

86-
from the root directory of the repository. Unit tests can then be
87-
executed to ensure the installation was successful by running
103+
from the root directory of the repository. To enable PyTorch support you will need to run
88104

89105
``` bash
90-
pytest tests/
106+
pip install .[torch]
91107
```
92108

93-
Alternatively, the `S2FFT` package may be installed directly from PyPi by running
109+
Unit tests can then be executed to ensure the installation was successful by running
94110

95111
``` bash
96-
pip install s2fft
112+
pytest tests/
97113
```
98114

99115
> [!NOTE]
@@ -123,6 +139,9 @@ f = fft.wigner.inverse_jax(flmn, L, N)
123139

124140
For further details on usage see the [documentation](https://astro-informatics.github.io/s2fft/) and associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/spherical_harmonic_transform.html).
125141

142+
> [!NOTE]
143+
> We also provide PyTorch support for the precompute version of our transforms. These are called through forward/inverse_torch(). Full PyTorch support will be provided in future releases.
144+
126145
## Benchmarking :hourglass_flowing_sand:
127146

128147
We benchmarked the spherical harmonic and Wigner transforms implemented

docs/api/precompute_transforms/index.rst

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,21 @@ Precompute Functions
1111
* - Function Name
1212
- Description
1313
* - :func:`~s2fft.precompute_transforms.spherical.inverse`
14-
- Wrapper function around NumPy/JAX inverse methods
14+
- Wrapper function around NumPy/JAX/Torch inverse methods
1515
* - :func:`~s2fft.precompute_transforms.spherical.inverse_transform`
1616
- Inverse spherical harmonic transform (NumPy)
1717
* - :func:`~s2fft.precompute_transforms.spherical.inverse_transform_jax`
1818
- Inverse spherical harmonic transform (JAX)
19+
* - :func:`~s2fft.precompute_transforms.spherical.inverse_transform_torch`
20+
- Inverse spherical harmonic transform (Torch)
1921
* - :func:`~s2fft.precompute_transforms.spherical.forward`
20-
- Wrapper function around NumPy/JAX forward methods
22+
- Wrapper function around NumPy/JAX/Torch forward methods
2123
* - :func:`~s2fft.precompute_transforms.spherical.forward_transform`
2224
- Forward spherical harmonic transform (NumPy)
2325
* - :func:`~s2fft.precompute_transforms.spherical.forward_transform_jax`
2426
- Forward spherical harmonic transform (JAX)
27+
* - :func:`~s2fft.precompute_transforms.spherical.forward_transform_torch`
28+
- Forward spherical harmonic transform (Torch)
2529

2630
.. list-table:: Wigner transforms.
2731
:widths: 25 25
@@ -30,17 +34,21 @@ Precompute Functions
3034
* - Function Name
3135
- Description
3236
* - :func:`~s2fft.precompute_transforms.wigner.inverse`
33-
- Wrapper function around NumPy/JAX inverse methods
37+
- Wrapper function around NumPy/JAX/Torch inverse methods
3438
* - :func:`~s2fft.precompute_transforms.wigner.inverse_transform`
3539
- Inverse Wigner transform (NumPy)
3640
* - :func:`~s2fft.precompute_transforms.wigner.inverse_transform_jax`
3741
- Inverse Wigner transform (JAX)
42+
* - :func:`~s2fft.precompute_transforms.wigner.inverse_transform_torch`
43+
- Inverse Wigner transform (Torch)
3844
* - :func:`~s2fft.precompute_transforms.wigner.forward`
39-
- Wrapper function around NumPy/JAX forward methods
45+
- Wrapper function around NumPy/JAX/Torch forward methods
4046
* - :func:`~s2fft.precompute_transforms.wigner.forward_transform`
4147
- Forward Wigner transform (NumPy)
4248
* - :func:`~s2fft.precompute_transforms.wigner.forward_transform_jax`
4349
- Forward Wigner transform (JAX)
50+
* - :func:`~s2fft.precompute_transforms.wigner.forward_transform_torch`
51+
- Forward Wigner transform (Torch)
4452

4553
.. list-table:: Constructing Kernels for precompute transforms.
4654
:widths: 25 25

docs/api/utility/index.rst

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,28 @@ Utility Functions
1616
- Computes the Inverse Fast Fourier Transform with spectral folding in the polar regions to mitigate aliasing (NumPy).
1717
* - :func:`~s2fft.utils.healpix_ffts.healpix_ifft_jax`
1818
- Computes the Inverse Fast Fourier Transform with spectral folding in the polar regions to mitigate aliasing (JAX).
19+
* - :func:`~s2fft.utils.healpix_ffts.healpix_ifft_torch`
20+
- Computes the Inverse Fast Fourier Transform with spectral folding in the polar regions to mitigate aliasing (Torch).
1921
* - :func:`~s2fft.utils.healpix_ffts.healpix_fft`
2022
- Wrapper function for the Forward Fast Fourier Transform with spectral back-projection in the polar regions to manually enforce Fourier periodicity.
2123
* - :func:`~s2fft.utils.healpix_ffts.healpix_fft_numpy`
2224
- Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions (NumPy).
2325
* - :func:`~s2fft.utils.healpix_ffts.healpix_fft_jax`
24-
- Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions (NumPy).
26+
- Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions (JAX).
27+
* - :func:`~s2fft.utils.healpix_ffts.healpix_fft_torch`
28+
- Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions (Torch).
2529
* - :func:`~s2fft.utils.healpix_ffts.spectral_folding`
2630
- Folds higher frequency Fourier coefficients back onto lower frequency coefficients (NumPy).
2731
* - :func:`~s2fft.utils.healpix_ffts.spectral_folding_jax`
2832
- Folds higher frequency Fourier coefficients back onto lower frequency coefficients (JAX).
33+
* - :func:`~s2fft.utils.healpix_ffts.spectral_folding_torch`
34+
- Folds higher frequency Fourier coefficients back onto lower frequency coefficients (Torch).
2935
* - :func:`~s2fft.utils.healpix_ffts.spectral_periodic_extension`
3036
- Extends lower frequency Fourier coefficients onto higher frequency coefficients (NumPy).
3137
* - :func:`~s2fft.utils.healpix_ffts.spectral_periodic_extension_jax`
3238
- Extends lower frequency Fourier coefficients onto higher frequency coefficients (JAX).
39+
* - :func:`~s2fft.utils.healpix_ffts.spectral_periodic_extension_torch`
40+
- Extends lower frequency Fourier coefficients onto higher frequency coefficients (Torch).
3341

3442

3543
.. list-table:: Quadrature functions.
@@ -61,8 +69,9 @@ Utility Functions
6169

6270
.. note::
6371

64-
JAX versions of these functions share an almost identical function trace and
65-
are simply accessed by the sub-module :func:`~s2fft.utils.quadrature_jax`.
72+
JAX and Torch versions of these functions share an almost identical function trace and
73+
are simply accessed by the sub-modules :func:`~s2fft.utils.quadrature_jax` and
74+
:func:`~s2fft.utils.quadrature_torch` respectively.
6675

6776
.. list-table:: Periodic resampling functions
6877
:widths: 25 25
@@ -102,8 +111,9 @@ Utility Functions
102111

103112
.. note::
104113

105-
JAX versions of these functions share an almost identical function trace and
106-
are simply accessed by the sub-module :func:`~s2fft.utils.resampling_jax`.
114+
JAX and Torch versions of these functions share an almost identical function trace and
115+
are simply accessed by the sub-modules :func:`~s2fft.utils.resampling_jax` and
116+
:func:`~s2fft.utils.resampling_torch` respectively.
107117

108118
.. list-table:: Rotation functions
109119
:widths: 25 25
@@ -124,8 +134,10 @@ Utility Functions
124134
signal_generator
125135
resampling
126136
resampling_jax
137+
resampling_torch
127138
quadrature
128139
quadrature_jax
140+
quadrature_torch
129141
healpix_ffts
130142
utils
131143
rotation
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
quadrature_torch
5+
**************************
6+
.. automodule:: s2fft.utils.quadrature_torch
7+
:members:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
resampling_torch
5+
**************************
6+
.. automodule:: s2fft.utils.resampling_torch
7+
:members:

docs/conf.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
author = "Matthew Price, Jason McEwen, Matthew Graham, Sofia Miñano, Devaraj Gopinathan"
2626

2727
# The short X.Y version
28-
version = "1.0.1"
28+
version = "1.0.2"
2929
# The full version, including alpha/beta/rc tags
30-
release = "1.0.1"
30+
release = "1.0.2"
3131

3232

3333
# -- General configuration ---------------------------------------------------
@@ -112,12 +112,12 @@
112112
# "icon": "fa-brands fa-youtube fa-2x",
113113
# "type": "fontawesome",
114114
# },
115-
# {
116-
# "name": "PyPi",
117-
# "url": "https://github.com/astro-informatics/s2fft/",
118-
# "icon": "_static/pypi.png",
119-
# "type": "local",
120-
# },
115+
{
116+
"name": "PyPi",
117+
"url": "https://pypi.org/project/s2fft/",
118+
"icon": "_static/pypi.png",
119+
"type": "local",
120+
},
121121
{
122122
"name": "GitHub",
123123
"url": "https://github.com/astro-informatics/s2fft/",

0 commit comments

Comments
 (0)