Skip to content

Commit 6444bb0

Browse files
committed
Update README and notebook to indicate wider torch support
1 parent 034ecbd commit 6444bb0

File tree

2 files changed

+69
-13
lines changed

2 files changed

+69
-13
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ f = fft.wigner.inverse_jax(flmn, L, N, method="jax")
195195
For further details on usage see the [documentation](https://astro-informatics.github.io/s2fft/) and associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/spherical_harmonic_transform.html).
196196

197197
> [!NOTE]
198-
> We also provide PyTorch support for the precompute version of our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html).
198+
> We also provide PyTorch support for our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html). This wraps the JAX implementations so JAX will need to be installed in addition to PyTorch.
199199
200200
## SSHT & HEALPix wrappers 💡
201201

notebooks/torch_frontend.ipynb

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,11 @@
4040
"import jax\n",
4141
"jax.config.update(\"jax_enable_x64\", True)\n",
4242
"import torch \n",
43-
"import numpy as np \n",
44-
"from s2fft.precompute_transforms.spherical import inverse, forward\n",
43+
"import numpy as np\n",
44+
"from s2fft.transforms.spherical import inverse, forward\n",
45+
"from s2fft.precompute_transforms.spherical import (\n",
46+
" inverse as precompute_inverse, forward as precompute_forward\n",
47+
")\n",
4548
"from s2fft.precompute_transforms.construct import spin_spherical_kernel_torch\n",
4649
"from s2fft.utils import signal_generator"
4750
]
@@ -68,7 +71,7 @@
6871
"cell_type": "markdown",
6972
"metadata": {},
7073
"source": [
71-
"For the fully precompute transform we must also generate the precompute kernels which we store as a torch tensors."
74+
"Now lets calculate the signal on the sphere by applying the inverse spherical harmonic transform"
7275
]
7376
},
7477
{
@@ -85,15 +88,14 @@
8588
}
8689
],
8790
"source": [
88-
"inverse_kernel = spin_spherical_kernel_torch(L, forward=False) \n",
89-
"forward_kernel = spin_spherical_kernel_torch(L, forward=True) "
91+
"f = inverse(flm, L, method=\"torch\")"
9092
]
9193
},
9294
{
9395
"cell_type": "markdown",
9496
"metadata": {},
9597
"source": [
96-
"Now lets calculate the signal on the sphere by applying the inverse spherical harmonic transform"
98+
"To calculate the corresponding spherical harmonic representation execute"
9799
]
98100
},
99101
{
@@ -102,36 +104,90 @@
102104
"metadata": {},
103105
"outputs": [],
104106
"source": [
105-
"f = inverse(flm, L, 0, inverse_kernel, method=\"torch\")"
107+
"flm_check = forward(f, L, method=\"torch\")"
106108
]
107109
},
108110
{
109111
"cell_type": "markdown",
110112
"metadata": {},
111113
"source": [
112-
"To calculate the corresponding spherical harmonic representation execute"
114+
"Finally, lets check the error on the round trip is as expected for 64 bit machine precision floating point arithmetic"
113115
]
114116
},
115117
{
116118
"cell_type": "code",
117119
"execution_count": 6,
118120
"metadata": {},
119-
"outputs": [],
121+
"outputs": [
122+
{
123+
"name": "stdout",
124+
"output_type": "stream",
125+
"text": [
126+
"Mean absolute error = 2.8915048238993476e-14\n"
127+
]
128+
}
129+
],
120130
"source": [
121-
"flm_check = forward(f, L, 0, forward_kernel, method=\"torch\")"
131+
"print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
122132
]
123133
},
124134
{
125135
"cell_type": "markdown",
126136
"metadata": {},
127137
"source": [
128-
"Finally, lets check the error on the roundtrip is at 64bit machine precision"
138+
"For the fully precompute transform we must also generate the precompute kernels which we store as a torch tensors."
129139
]
130140
},
131141
{
132142
"cell_type": "code",
133143
"execution_count": 7,
134144
"metadata": {},
145+
"outputs": [],
146+
"source": [
147+
"inverse_kernel = spin_spherical_kernel_torch(L, forward=False) \n",
148+
"forward_kernel = spin_spherical_kernel_torch(L, forward=True) "
149+
]
150+
},
151+
{
152+
"cell_type": "markdown",
153+
"metadata": {},
154+
"source": [
155+
"We then pass the kernels as additional arguments to the transform functions"
156+
]
157+
},
158+
{
159+
"cell_type": "code",
160+
"execution_count": null,
161+
"metadata": {},
162+
"outputs": [
163+
{
164+
"ename": "NameError",
165+
"evalue": "name 'orward_kernel' is not defined",
166+
"output_type": "error",
167+
"traceback": [
168+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
169+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
170+
"Cell \u001b[0;32mIn[8], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m precompute_f \u001b[38;5;241m=\u001b[39m precompute_inverse(flm, L, kernel\u001b[38;5;241m=\u001b[39minverse_kernel, method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m precompute_flm_check \u001b[38;5;241m=\u001b[39m precompute_forward(f, L, kernel\u001b[38;5;241m=\u001b[39m\u001b[43morward_kernel\u001b[49m, method\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
171+
"\u001b[0;31mNameError\u001b[0m: name 'orward_kernel' is not defined"
172+
]
173+
}
174+
],
175+
"source": [
176+
"precompute_f = precompute_inverse(flm, L, kernel=inverse_kernel, method=\"torch\")\n",
177+
"precompute_flm_check = precompute_forward(f, L, kernel=forward_kernel, method=\"torch\")"
178+
]
179+
},
180+
{
181+
"cell_type": "markdown",
182+
"metadata": {},
183+
"source": [
184+
"Again, we check the error on the round trip is as expected"
185+
]
186+
},
187+
{
188+
"cell_type": "code",
189+
"execution_count": null,
190+
"metadata": {},
135191
"outputs": [
136192
{
137193
"name": "stdout",
@@ -142,7 +198,7 @@
142198
}
143199
],
144200
"source": [
145-
"print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
201+
"print(f\"Mean absolute error = {np.nanmean(np.abs(precompute_flm_check - flm))}\")"
146202
]
147203
}
148204
],

0 commit comments

Comments
 (0)