Skip to content

Commit 26ff37a

Browse files
authored
Merge pull request #182 from astro-informatics/feature/rotations
Feature/rotations
2 parents 2d2370f + cfb3861 commit 26ff37a

File tree

15 files changed

+466
-14
lines changed

15 files changed

+466
-14
lines changed

docs/api/recursions/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ Wigner-d recursions
8989
- Description
9090
* - :func:`~s2fft.recursions.risbo.compute_full`
9191
- Compute Wigner-d at argument :math:`\beta` for full plane using Risbo recursion.
92+
* - :func:`~s2fft.recursions.risbo_jax.compute_full`
93+
- Compute Wigner-d at argument :math:`\beta` for full plane using Risbo recursion (JAX implementation).
9294

9395
.. warning::
9496

docs/api/recursions/risbo_jax.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
Risbo JAX
5+
**************************
6+
.. automodule:: s2fft.recursions.risbo_jax
7+
:members:

docs/api/utility/index.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,22 @@ Utility Functions
100100
* - :func:`~s2fft.utils.signal_generator.generate_flmn`
101101
- Generate a 3D set of random Wigner coefficients.
102102

103-
104103
.. note::
105104

106105
JAX versions of these functions share an almost identical function trace and
107106
are simply accessed by the sub-module :func:`~s2fft.utils.resampling_jax`.
108107

108+
.. list-table:: Rotation functions
109+
:widths: 25 25
110+
:header-rows: 1
111+
112+
* - Function Name
113+
- Description
114+
* - :func:`~s2fft.utils.rotation.rotate_flms`
115+
- Euler rotates spherical harmonic coefficients by given angle in zyz convention.
116+
* - :func:`~s2fft.utils.rotation.generate_rotate_dls`
117+
- Generates an array of all reduced Wigner d-function coefficients for angle beta.
118+
109119
.. toctree::
110120
:hidden:
111121
:maxdepth: 2
@@ -118,5 +128,6 @@ Utility Functions
118128
quadrature_jax
119129
healpix_ffts
120130
utils
131+
rotation
121132
logs
122133

docs/api/utility/rotation.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
rotations
5+
**************************
6+
.. automodule:: s2fft.utils.rotation
7+
:members:

docs/tutorials/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,4 @@ the case for many other methods that scale linearly with spin).
6767

6868
spherical_harmonic/spherical_harmonic_transform.nblink
6969
wigner/wigner_transform.nblink
70+
rotation/rotation.nblink
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"path": "../../../notebooks/spherical_rotation.ipynb"
3+
}

notebooks/spherical_rotation.ipynb

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# __Rotate a signal__\n",
8+
"\n",
9+
"---\n",
10+
"\n"
11+
]
12+
},
13+
{
14+
"attachments": {},
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"This tutorial demonstrates how to use `S2FFT` to rotate a signal on the sphere. A signal can be rotated in pixel space, however this can introduce artifacts. The best way to perform a rotation is through spherical harmonic space using the Wigner d-matrices, that is:\n",
19+
"\n",
20+
"- forward spherical harmonic transform\n",
21+
"- rotation on the flm coefficients\n",
22+
"- inverse spherical harmonic transform\n",
23+
"\n",
24+
"Specifically, we will adopt the sampling scheme of [McEwen & Wiaux (2012)](https://arxiv.org/abs/1110.6298). For our purposes here we'll just generate a random bandlimited signal."
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": 5,
30+
"metadata": {},
31+
"outputs": [],
32+
"source": [
33+
"from jax.config import config\n",
34+
"config.update(\"jax_enable_x64\", True)\n",
35+
"\n",
36+
"import numpy as np\n",
37+
"import s2fft \n",
38+
"import plotting_functions\n",
39+
"\n",
40+
"L = 128\n",
41+
"sampling = \"mw\"\n",
42+
"rng = np.random.default_rng(12346161)\n",
43+
"flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n",
44+
"f = s2fft.inverse(flm, L)"
45+
]
46+
},
47+
{
48+
"attachments": {},
49+
"cell_type": "markdown",
50+
"metadata": {},
51+
"source": [
52+
"### Execute the rotation steps\n",
53+
"\n",
54+
"---\n",
55+
"\n",
56+
"First, we will run the JAX function to compute the spherical harmonic transform of our signal"
57+
]
58+
},
59+
{
60+
"cell_type": "code",
61+
"execution_count": 7,
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"flm = s2fft.forward_jax(f, L, reality=True)"
66+
]
67+
},
68+
{
69+
"cell_type": "markdown",
70+
"metadata": {},
71+
"source": [
72+
"Now apply the rotation (here pi/2 in each of alpha, beta, gamma) on the harmonic coefficients flm "
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": 8,
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"flm_rotated = s2fft.rotate_flms(flm, L, (np.pi/2, np.pi/2,np.pi/2))"
82+
]
83+
},
84+
{
85+
"attachments": {},
86+
"cell_type": "markdown",
87+
"metadata": {},
88+
"source": [
89+
"Finally, we will run the JAX function to compute the inverse spherical harmonic transform to get back to pixel space"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": 9,
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"f_rotated = s2fft.inverse_jax(flm_rotated, L, reality=True)"
99+
]
100+
}
101+
],
102+
"metadata": {
103+
"kernelspec": {
104+
"display_name": "Python 3.10.4 ('s2fft')",
105+
"language": "python",
106+
"name": "python3"
107+
},
108+
"language_info": {
109+
"codemirror_mode": {
110+
"name": "ipython",
111+
"version": 3
112+
},
113+
"file_extension": ".py",
114+
"mimetype": "text/x-python",
115+
"name": "python",
116+
"nbconvert_exporter": "python",
117+
"pygments_lexer": "ipython3",
118+
"version": "3.10.4"
119+
},
120+
"orig_nbformat": 4,
121+
"vscode": {
122+
"interpreter": {
123+
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
124+
}
125+
}
126+
},
127+
"nbformat": 4,
128+
"nbformat_minor": 2
129+
}

s2fft/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
generate_precomputes_wigner,
88
generate_precomputes_wigner_jax,
99
)
10+
from .utils.rotation import rotate_flms, generate_rotate_dls
1011

1112
import logging
1213
from jax.config import config

s2fft/recursions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from . import trapani
22
from . import risbo
3+
from . import risbo_jax
34
from . import turok
45
from . import turok_jax
56
from . import price_mcewen

s2fft/recursions/risbo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
5656
# from l - 1 to l - 1/2.
5757
dd = np.zeros((2 * el + 2, 2 * el + 2))
5858
j = 2 * el - 1
59-
rj = float(j) # TODO: is this necessary?
59+
6060
for k in range(0, j):
6161
sqrt_jmk = np.sqrt(j - k)
6262
sqrt_kp1 = np.sqrt(k + 1)
@@ -77,7 +77,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
7777
# the plane of the dl-matrix to 0.0.
7878
dl[-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1] = 0.0
7979
j = 2 * el
80-
rj = float(j) # TODO: is this necessary?
80+
8181
for k in range(0, j):
8282
sqrt_jmk = np.sqrt(j - k)
8383
sqrt_kp1 = np.sqrt(k + 1)

0 commit comments

Comments
 (0)