Skip to content

Commit c3ba92d

Browse files
committed
update test install for [torch], add torch notebook, update docs
1 parent bef6d34 commit c3ba92d

File tree

6 files changed

+194
-6
lines changed

6 files changed

+194
-6
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
python -m pip install --upgrade pip
3131
pip install -r requirements/requirements-tests.txt
3232
pip install -r requirements/requirements-core.txt
33-
pip install .
33+
pip install .\[torch\]
3434
3535
- name: Run tests
3636
run: |

docs/tutorials/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,4 @@ the case for many other methods that scale linearly with spin).
6868
spherical_harmonic/spherical_harmonic_transform.nblink
6969
wigner/wigner_transform.nblink
7070
rotation/rotation.nblink
71+
torch_frontend/torch_frontend.nblink
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"path": "../../../notebooks/torch_frontend.ipynb"
3+
}

notebooks/torch_frontend.ipynb

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# __Torch frontend guide__\n",
8+
"\n",
9+
"---\n",
10+
"\n"
11+
]
12+
},
13+
{
14+
"cell_type": "markdown",
15+
"metadata": {},
16+
"source": [
17+
"This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms.\n",
18+
"\n",
19+
"Note that to install `S2FFT` with torch functionality from ``PyPi`` run \n",
20+
"``` bash\n",
21+
"pip install s2fft[torch] \n",
22+
"```\n",
23+
"or from source by cloning the repository and running \n",
24+
"``` bash\n",
25+
"pip install .\\[torch\\] \n",
26+
"```\n",
27+
"\n",
28+
"Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models."
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": 1,
34+
"metadata": {},
35+
"outputs": [
36+
{
37+
"name": "stderr",
38+
"output_type": "stream",
39+
"text": [
40+
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.\n"
41+
]
42+
}
43+
],
44+
"source": [
45+
"import torch \n",
46+
"import numpy as np \n",
47+
"from s2fft.precompute_transforms.spherical import inverse, forward\n",
48+
"from s2fft.precompute_transforms.construct import spin_spherical_kernel\n",
49+
"from s2fft.utils import signal_generator"
50+
]
51+
},
52+
{
53+
"cell_type": "markdown",
54+
"metadata": {},
55+
"source": [
56+
"Lets set up a mock problem by specifiying a bandlimit $L$ and generating some arbitrary harmonic coefficients."
57+
]
58+
},
59+
{
60+
"cell_type": "code",
61+
"execution_count": 2,
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"L = 64 # Spherical harmonic bandlimit\n",
66+
"rng = np.random.default_rng(1234951510) # Random seed for signal generator\n",
67+
"flm = signal_generator.generate_flm(rng, L, using_torch=True) # Random set of spherical harmonic coefficients"
68+
]
69+
},
70+
{
71+
"cell_type": "markdown",
72+
"metadata": {},
73+
"source": [
74+
"For the fully precompute transform we must also generate the precompute kernels which we store as a torch tensors."
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": 3,
80+
"metadata": {},
81+
"outputs": [],
82+
"source": [
83+
"inverse_kernel = spin_spherical_kernel(L, using_torch=True, forward=False) \n",
84+
"forward_kernel = spin_spherical_kernel(L, using_torch=True, forward=True) "
85+
]
86+
},
87+
{
88+
"cell_type": "markdown",
89+
"metadata": {},
90+
"source": [
91+
"Now lets calculate the signal on the sphere by applying the inverse spherical harmonic transform"
92+
]
93+
},
94+
{
95+
"cell_type": "code",
96+
"execution_count": 4,
97+
"metadata": {},
98+
"outputs": [],
99+
"source": [
100+
"f = inverse(flm, L, 0, inverse_kernel, method=\"torch\")"
101+
]
102+
},
103+
{
104+
"cell_type": "markdown",
105+
"metadata": {},
106+
"source": [
107+
"To calculate the corresponding spherical harmonic representation execute"
108+
]
109+
},
110+
{
111+
"cell_type": "code",
112+
"execution_count": 5,
113+
"metadata": {},
114+
"outputs": [],
115+
"source": [
116+
"flm_check = forward(f, L, 0, forward_kernel, method=\"torch\")"
117+
]
118+
},
119+
{
120+
"cell_type": "markdown",
121+
"metadata": {},
122+
"source": [
123+
"Finally, lets check the error on the roundtrip is at 64bit machine precision"
124+
]
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": 6,
129+
"metadata": {},
130+
"outputs": [
131+
{
132+
"name": "stdout",
133+
"output_type": "stream",
134+
"text": [
135+
"Mean absolute error = 1.1866908936078849e-14\n"
136+
]
137+
}
138+
],
139+
"source": [
140+
"print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
141+
]
142+
}
143+
],
144+
"metadata": {
145+
"kernelspec": {
146+
"display_name": "Python 3.10.4 ('s2fft')",
147+
"language": "python",
148+
"name": "python3"
149+
},
150+
"language_info": {
151+
"codemirror_mode": {
152+
"name": "ipython",
153+
"version": 3
154+
},
155+
"file_extension": ".py",
156+
"mimetype": "text/x-python",
157+
"name": "python",
158+
"nbconvert_exporter": "python",
159+
"pygments_lexer": "ipython3",
160+
"version": "3.10.0"
161+
},
162+
"orig_nbformat": 4,
163+
"vscode": {
164+
"interpreter": {
165+
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
166+
}
167+
}
168+
},
169+
"nbformat": 4,
170+
"nbformat_minor": 2
171+
}

s2fft/precompute_transforms/construct.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
import jax.numpy as jnp
7-
7+
import torch
88
from s2fft.sampling import s2_samples as samples
99
from s2fft.utils import quadrature, quadrature_jax
1010
from s2fft import recursions
@@ -18,6 +18,7 @@ def spin_spherical_kernel(
1818
sampling: str = "mw",
1919
nside: int = None,
2020
forward: bool = False,
21+
using_torch: bool = False,
2122
):
2223
r"""Precompute the wigner-d kernel for spin-spherical transform. This can be
2324
drastically faster but comes at a :math:`\mathcal{O}(L^3)` memory overhead, making
@@ -41,6 +42,8 @@ def spin_spherical_kernel(
4142
forward (bool, optional): Whether to provide forward or inverse shift.
4243
Defaults to False.
4344
45+
using_torch (bool, optional): Desired frontend functionality. Defaults to False.
46+
4447
Returns:
4548
np.ndarray: Transform kernel for spin-spherical harmonic transform.
4649
"""
@@ -78,7 +81,7 @@ def spin_spherical_kernel(
7881
healpix_phase_shifts(L, nside, forward)[:, m_start_ind:],
7982
)
8083

81-
return dl
84+
return torch.from_numpy(dl) if using_torch else dl
8285

8386

8487
def spin_spherical_kernel_jax(
@@ -166,6 +169,7 @@ def wigner_kernel(
166169
sampling: str = "mw",
167170
nside: int = None,
168171
forward: bool = False,
172+
using_torch: bool = False,
169173
):
170174
r"""Precompute the wigner-d kernels required for a Wigner transform. This can be
171175
drastically faster but comes at a :math:`\mathcal{O}(NL^3)` memory overhead, making
@@ -189,6 +193,8 @@ def wigner_kernel(
189193
forward (bool, optional): Whether to provide forward or inverse shift.
190194
Defaults to False.
191195
196+
using_torch (bool, optional): Desired frontend functionality. Defaults to False.
197+
192198
Returns:
193199
np.ndarray: Transform kernel for Wigner transform.
194200
"""
@@ -227,7 +233,7 @@ def wigner_kernel(
227233
healpix_phase_shifts(L, nside, forward),
228234
)
229235

230-
return dl
236+
return torch.from_numpy(dl) if using_torch else dl
231237

232238

233239
def wigner_kernel_jax(

s2fft/utils/signal_generator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import torch
23
from s2fft.sampling import s2_samples as samples
34
from s2fft.sampling import so3_samples as wigner_samples
45

@@ -9,6 +10,7 @@ def generate_flm(
910
L_lower: int = 0,
1011
spin: int = 0,
1112
reality: bool = False,
13+
using_torch: bool = False,
1214
) -> np.ndarray:
1315
r"""Generate a 2D set of random harmonic coefficients.
1416
@@ -26,6 +28,8 @@ def generate_flm(
2628
2729
reality (bool, optional): Reality of signal. Defaults to False.
2830
31+
using_torch (bool, optional): Desired frontend functionality. Defaults to False.
32+
2933
Returns:
3034
np.ndarray: Random set of spherical harmonic coefficients.
3135
@@ -45,7 +49,7 @@ def generate_flm(
4549
else:
4650
flm[el, -m + L - 1] = rng.uniform() + 1j * rng.uniform()
4751

48-
return flm
52+
return torch.from_numpy(flm) if using_torch else flm
4953

5054

5155
def generate_flmn(
@@ -54,6 +58,7 @@ def generate_flmn(
5458
N: int = 1,
5559
L_lower: int = 0,
5660
reality: bool = False,
61+
using_torch: bool = False,
5762
) -> np.ndarray:
5863
r"""Generate a 3D set of random Wigner coefficients.
5964
Note:
@@ -70,6 +75,8 @@ def generate_flmn(
7075
7176
reality (bool, optional): Reality of signal. Defaults to False.
7277
78+
using_torch (bool, optional): Desired frontend functionality. Defaults to False.
79+
7380
Returns:
7481
7582
np.ndarray: Random set of Wigner coefficients.
@@ -97,4 +104,4 @@ def generate_flmn(
97104
else:
98105
flmn[N - 1 + n, el, -m + L - 1] = rng.uniform() + 1j * rng.uniform()
99106

100-
return flmn
107+
return torch.from_numpy(flmn) if using_torch else flmn

0 commit comments

Comments
 (0)