Skip to content

Commit f17ac17

Browse files
authored
Merge pull request #185 from astro-informatics/features/pytorch_precompute_transforms
add pytorch support for precompute transform
2 parents b82e065 + 1cadd79 commit f17ac17

26 files changed

+1638
-68
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
pip install jaxlib
3030
pip install -r requirements/requirements-core.txt
3131
pip install -r requirements/requirements-docs.txt
32-
pip install .
32+
pip install .\[torch\]
3333
3434
- name: Build Documentation
3535
run: |

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
python -m pip install --upgrade pip
3131
pip install -r requirements/requirements-tests.txt
3232
pip install -r requirements/requirements-core.txt
33-
pip install .
33+
pip install .\[torch\]
3434
3535
- name: Run tests
3636
run: |

.pip_readme.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ for adjoint transformations where needed, and comes with different
2727
optimisations (precompute or not) that one may select depending on
2828
available resources and desired angular resolution $L$.
2929

30+
As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
31+
precompute transforms. In future releases this support will be extended to our
32+
on-the-fly algorithms.
33+
3034
Documentation
3135
=============
3236
Read the full documentation `here <https://astro-informatics.github.io/s2fft/>`_.

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ for adjoint transformations where needed, and comes with different
2222
optimisations (precompute or not) that one may select depending on
2323
available resources and desired angular resolution $L$.
2424

25+
As of version 1.0.2 `S2FFT` also provides PyTorch implementations of underlying
26+
precompute transforms. In future releases this support will be extended to our
27+
on-the-fly algorithms.
28+
2529
## Algorithms :zap:
2630

2731
`S2FFT` leverages new algorithmic structures that can he highly
@@ -123,6 +127,9 @@ f = fft.wigner.inverse_jax(flmn, L, N)
123127

124128
For further details on usage see the [documentation](https://astro-informatics.github.io/s2fft/) and associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/spherical_harmonic_transform.html).
125129

130+
> [!NOTE]
131+
> We also provide PyTorch support for the precompute version of our transforms. These are called through forward/inverse_torch(). Full PyTorch support will be provided in future releases.
132+
126133
## Benchmarking :hourglass_flowing_sand:
127134

128135
We benchmarked the spherical harmonic and Wigner transforms implemented

docs/api/precompute_transforms/index.rst

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,21 @@ Precompute Functions
1111
* - Function Name
1212
- Description
1313
* - :func:`~s2fft.precompute_transforms.spherical.inverse`
14-
- Wrapper function around NumPy/JAX inverse methods
14+
- Wrapper function around NumPy/JAX/Torch inverse methods
1515
* - :func:`~s2fft.precompute_transforms.spherical.inverse_transform`
1616
- Inverse spherical harmonic transform (NumPy)
1717
* - :func:`~s2fft.precompute_transforms.spherical.inverse_transform_jax`
1818
- Inverse spherical harmonic transform (JAX)
19+
* - :func:`~s2fft.precompute_transforms.spherical.inverse_transform_torch`
20+
- Inverse spherical harmonic transform (Torch)
1921
* - :func:`~s2fft.precompute_transforms.spherical.forward`
20-
- Wrapper function around NumPy/JAX forward methods
22+
- Wrapper function around NumPy/JAX/Torch forward methods
2123
* - :func:`~s2fft.precompute_transforms.spherical.forward_transform`
2224
- Forward spherical harmonic transform (NumPy)
2325
* - :func:`~s2fft.precompute_transforms.spherical.forward_transform_jax`
2426
- Forward spherical harmonic transform (JAX)
27+
* - :func:`~s2fft.precompute_transforms.spherical.forward_transform_torch`
28+
- Forward spherical harmonic transform (Torch)
2529

2630
.. list-table:: Wigner transforms.
2731
:widths: 25 25
@@ -30,17 +34,21 @@ Precompute Functions
3034
* - Function Name
3135
- Description
3236
* - :func:`~s2fft.precompute_transforms.wigner.inverse`
33-
- Wrapper function around NumPy/JAX inverse methods
37+
- Wrapper function around NumPy/JAX/Torch inverse methods
3438
* - :func:`~s2fft.precompute_transforms.wigner.inverse_transform`
3539
- Inverse Wigner transform (NumPy)
3640
* - :func:`~s2fft.precompute_transforms.wigner.inverse_transform_jax`
3741
- Inverse Wigner transform (JAX)
42+
* - :func:`~s2fft.precompute_transforms.wigner.inverse_transform_torch`
43+
- Inverse Wigner transform (Torch)
3844
* - :func:`~s2fft.precompute_transforms.wigner.forward`
39-
- Wrapper function around NumPy/JAX forward methods
45+
- Wrapper function around NumPy/JAX/Torch forward methods
4046
* - :func:`~s2fft.precompute_transforms.wigner.forward_transform`
4147
- Forward Wigner transform (NumPy)
4248
* - :func:`~s2fft.precompute_transforms.wigner.forward_transform_jax`
4349
- Forward Wigner transform (JAX)
50+
* - :func:`~s2fft.precompute_transforms.wigner.forward_transform_torch`
51+
- Forward Wigner transform (Torch)
4452

4553
.. list-table:: Constructing Kernels for precompute transforms.
4654
:widths: 25 25

docs/api/utility/index.rst

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,28 @@ Utility Functions
1616
- Computes the Inverse Fast Fourier Transform with spectral folding in the polar regions to mitigate aliasing (NumPy).
1717
* - :func:`~s2fft.utils.healpix_ffts.healpix_ifft_jax`
1818
- Computes the Inverse Fast Fourier Transform with spectral folding in the polar regions to mitigate aliasing (JAX).
19+
* - :func:`~s2fft.utils.healpix_ffts.healpix_ifft_torch`
20+
- Computes the Inverse Fast Fourier Transform with spectral folding in the polar regions to mitigate aliasing (Torch).
1921
* - :func:`~s2fft.utils.healpix_ffts.healpix_fft`
2022
- Wrapper function for the Forward Fast Fourier Transform with spectral back-projection in the polar regions to manually enforce Fourier periodicity.
2123
* - :func:`~s2fft.utils.healpix_ffts.healpix_fft_numpy`
2224
- Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions (NumPy).
2325
* - :func:`~s2fft.utils.healpix_ffts.healpix_fft_jax`
24-
- Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions (NumPy).
26+
- Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions (JAX).
27+
* - :func:`~s2fft.utils.healpix_ffts.healpix_fft_torch`
28+
- Computes the Forward Fast Fourier Transform with spectral back-projection in the polar regions (Torch).
2529
* - :func:`~s2fft.utils.healpix_ffts.spectral_folding`
2630
- Folds higher frequency Fourier coefficients back onto lower frequency coefficients (NumPy).
2731
* - :func:`~s2fft.utils.healpix_ffts.spectral_folding_jax`
2832
- Folds higher frequency Fourier coefficients back onto lower frequency coefficients (JAX).
33+
* - :func:`~s2fft.utils.healpix_ffts.spectral_folding_torch`
34+
- Folds higher frequency Fourier coefficients back onto lower frequency coefficients (Torch).
2935
* - :func:`~s2fft.utils.healpix_ffts.spectral_periodic_extension`
3036
- Extends lower frequency Fourier coefficients onto higher frequency coefficients (NumPy).
3137
* - :func:`~s2fft.utils.healpix_ffts.spectral_periodic_extension_jax`
3238
- Extends lower frequency Fourier coefficients onto higher frequency coefficients (JAX).
39+
* - :func:`~s2fft.utils.healpix_ffts.spectral_periodic_extension_torch`
40+
- Extends lower frequency Fourier coefficients onto higher frequency coefficients (Torch).
3341

3442

3543
.. list-table:: Quadrature functions.
@@ -61,8 +69,9 @@ Utility Functions
6169

6270
.. note::
6371

64-
JAX versions of these functions share an almost identical function trace and
65-
are simply accessed by the sub-module :func:`~s2fft.utils.quadrature_jax`.
72+
JAX and Torch versions of these functions share an almost identical function trace and
73+
are simply accessed by the sub-modules :func:`~s2fft.utils.quadrature_jax` and
74+
:func:`~s2fft.utils.quadrature_torch` respectively.
6675

6776
.. list-table:: Periodic resampling functions
6877
:widths: 25 25
@@ -102,8 +111,9 @@ Utility Functions
102111

103112
.. note::
104113

105-
JAX versions of these functions share an almost identical function trace and
106-
are simply accessed by the sub-module :func:`~s2fft.utils.resampling_jax`.
114+
JAX and Torch versions of these functions share an almost identical function trace and
115+
are simply accessed by the sub-modules :func:`~s2fft.utils.resampling_jax` and
116+
:func:`~s2fft.utils.resampling_torch` respectively.
107117

108118
.. list-table:: Rotation functions
109119
:widths: 25 25
@@ -124,8 +134,10 @@ Utility Functions
124134
signal_generator
125135
resampling
126136
resampling_jax
137+
resampling_torch
127138
quadrature
128139
quadrature_jax
140+
quadrature_torch
129141
healpix_ffts
130142
utils
131143
rotation
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
quadrature_torch
5+
**************************
6+
.. automodule:: s2fft.utils.quadrature_torch
7+
:members:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
resampling_torch
5+
**************************
6+
.. automodule:: s2fft.utils.resampling_torch
7+
:members:

docs/conf.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
author = "Matthew Price, Jason McEwen, Matthew Graham, Sofia Miñano, Devaraj Gopinathan"
2626

2727
# The short X.Y version
28-
version = "1.0.1"
28+
version = "1.0.2"
2929
# The full version, including alpha/beta/rc tags
30-
release = "1.0.1"
30+
release = "1.0.2"
3131

3232

3333
# -- General configuration ---------------------------------------------------
@@ -112,12 +112,12 @@
112112
# "icon": "fa-brands fa-youtube fa-2x",
113113
# "type": "fontawesome",
114114
# },
115-
# {
116-
# "name": "PyPi",
117-
# "url": "https://github.com/astro-informatics/s2fft/",
118-
# "icon": "_static/pypi.png",
119-
# "type": "local",
120-
# },
115+
{
116+
"name": "PyPi",
117+
"url": "https://pypi.org/project/s2fft/",
118+
"icon": "_static/pypi.png",
119+
"type": "local",
120+
},
121121
{
122122
"name": "GitHub",
123123
"url": "https://github.com/astro-informatics/s2fft/",

docs/index.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ transforms (for both real and complex signals), with support for adjoint transfo
1010
where needed, and comes with different optimisations (precompute or not) that one
1111
may select depending on available resources and desired angular resolution :math:`L`.
1212

13+
As of version 1.0.2 ``S2FFT`` also provides PyTorch implementations of underlying
14+
precompute transforms. In future releases this support will be extended to our
15+
on-the-fly algorithms.
16+
1317
Algorithms |:zap:|
1418
-------------------
1519

0 commit comments

Comments
 (0)