Skip to content

Commit 6eb14df

Browse files
committed
rename augmentation to rotation, add notebook, update docs
1 parent c2bf786 commit 6eb14df

File tree

10 files changed

+235
-21
lines changed

10 files changed

+235
-21
lines changed

docs/api/utility/index.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,15 @@ Utility Functions
105105
JAX versions of these functions share an almost identical function trace and
106106
are simply accessed by the sub-module :func:`~s2fft.utils.resampling_jax`.
107107

108-
.. list-table:: Augmentation functions
108+
.. list-table:: Rotation functions
109109
:widths: 25 25
110110
:header-rows: 1
111111

112112
* - Function Name
113113
- Description
114-
* - :func:`~s2fft.utils.augmentation.rotate_flms`
114+
* - :func:`~s2fft.utils.rotation.rotate_flms`
115115
- Euler rotates spherical harmonic coefficients by given angle in zyz convention.
116-
* - :func:`~s2fft.utils.augmentation.generate_rotate_dls`
116+
* - :func:`~s2fft.utils.rotation.generate_rotate_dls`
117117
- Generates an array of all reduced Wigner d-function coefficients for angle beta.
118118

119119
.. toctree::
@@ -128,6 +128,6 @@ Utility Functions
128128
quadrature_jax
129129
healpix_ffts
130130
utils
131-
augmentation
131+
rotation
132132
logs
133133

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
:html_theme.sidebar_secondary.remove:
22

33
**************************
4-
augmentation
4+
rotations
55
**************************
6-
.. automodule:: s2fft.utils.augmentation
6+
.. automodule:: s2fft.utils.rotation
77
:members:

docs/tutorials/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,4 @@ the case for many other methods that scale linearly with spin).
6767

6868
spherical_harmonic/spherical_harmonic_transform.nblink
6969
wigner/wigner_transform.nblink
70+
rotation/rotation.nblink
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"path": "../../../notebooks/spherical_rotation.ipynb"
3+
}

notebooks/spherical_rotation.ipynb

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# __Rotate a signal__\n",
8+
"\n",
9+
"---\n",
10+
"\n"
11+
]
12+
},
13+
{
14+
"attachments": {},
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"This tutorial demonstrates how to use `S2FFT` to rotate a signal on the sphere. A signal can be rotated in pixel space, however this can introduce artifacts. The best way to perform a rotation is through spherical harmonic space using the Wigner d-matrices, that is:\n",
19+
"\n",
20+
"- forward spherical harmonic transform\n",
21+
"- rotation on the flm coefficients\n",
22+
"- inverse spherical harmonic transform\n",
23+
"\n",
24+
"Specifically, we will adopt the sampling scheme of [McEwen & Wiaux (2012)](https://arxiv.org/abs/1110.6298). For our purposes here we'll just generate a random bandlimited signal."
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": 5,
30+
"metadata": {},
31+
"outputs": [],
32+
"source": [
33+
"from jax.config import config\n",
34+
"config.update(\"jax_enable_x64\", True)\n",
35+
"\n",
36+
"import numpy as np\n",
37+
"import s2fft \n",
38+
"import plotting_functions\n",
39+
"\n",
40+
"L = 128\n",
41+
"sampling = \"mw\"\n",
42+
"rng = np.random.default_rng(12346161)\n",
43+
"flm = s2fft.utils.signal_generator.generate_flm(rng, L)\n",
44+
"f = s2fft.inverse(flm, L)"
45+
]
46+
},
47+
{
48+
"attachments": {},
49+
"cell_type": "markdown",
50+
"metadata": {},
51+
"source": [
52+
"### Execute the rotation steps\n",
53+
"\n",
54+
"---\n",
55+
"\n",
56+
"First, we will run the JAX function to compute the spherical harmonic transform of our signal"
57+
]
58+
},
59+
{
60+
"cell_type": "code",
61+
"execution_count": 7,
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"flm = s2fft.forward_jax(f, L, reality=True)"
66+
]
67+
},
68+
{
69+
"cell_type": "markdown",
70+
"metadata": {},
71+
"source": [
72+
"Now apply the rotation (here pi/2 in each of alpha, beta, gamma) on the harmonic coefficients flm "
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": 8,
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"flm_rotated = s2fft.rotate_flms(flm, L, (np.pi/2, np.pi/2,np.pi/2))"
82+
]
83+
},
84+
{
85+
"attachments": {},
86+
"cell_type": "markdown",
87+
"metadata": {},
88+
"source": [
89+
"Finally, we will run the JAX function to compute the inverse spherical harmonic transform to get back to pixel space"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": 9,
95+
"metadata": {},
96+
"outputs": [],
97+
"source": [
98+
"f_rotated = s2fft.inverse_jax(flm_rotated, L, reality=True)"
99+
]
100+
}
101+
],
102+
"metadata": {
103+
"kernelspec": {
104+
"display_name": "Python 3.10.4 ('s2fft')",
105+
"language": "python",
106+
"name": "python3"
107+
},
108+
"language_info": {
109+
"codemirror_mode": {
110+
"name": "ipython",
111+
"version": 3
112+
},
113+
"file_extension": ".py",
114+
"mimetype": "text/x-python",
115+
"name": "python",
116+
"nbconvert_exporter": "python",
117+
"pygments_lexer": "ipython3",
118+
"version": "3.10.4"
119+
},
120+
"orig_nbformat": 4,
121+
"vscode": {
122+
"interpreter": {
123+
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
124+
}
125+
}
126+
},
127+
"nbformat": 4,
128+
"nbformat_minor": 2
129+
}

s2fft/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
generate_precomputes_wigner,
88
generate_precomputes_wigner_jax,
99
)
10+
from .utils.rotation import rotate_flms, generate_rotate_dls
1011

1112
import logging
1213
from jax.config import config

s2fft/recursions/risbo_jax.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,9 @@ def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
2525
np.ndarray: Plane of Wigner-d for `el` and `beta`, with full plane computed.
2626
"""
2727

28-
if beta == 0:
29-
dl = dl.at[-el + L - 1, el + L - 1].set(1.0)
30-
3128
if el == 0:
3229
dl = dl.at[el + L - 1, el + L - 1].set(1.0)
3330
return dl
34-
35-
if beta == jnp.pi:
36-
dl = dl.at[:, :].multiply(0)
37-
ind = jnp.arange(-el + L - 1, el + L)
38-
dl = jnp.flip(dl.at[ind, ind].add(1.0), axis=0)
39-
dl = jnp.einsum(
40-
"nm,m->nm", dl, (-1) ** (el + jnp.arange(-L + 1, L)), optimize=True
41-
)
42-
return dl
43-
4431
if el == 1:
4532
cosb = jnp.cos(beta)
4633
sinb = jnp.sin(beta)

s2fft/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from . import resampling_jax
55
from . import healpix_ffts
66
from . import signal_generator
7-
from . import augmentation
7+
from . import rotation

s2fft/utils/rotation.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import jax.numpy as jnp
2+
from jax import jit
3+
from functools import partial
4+
from typing import Tuple
5+
6+
from s2fft.recursions.risbo_jax import compute_full
7+
8+
9+
@partial(jit, static_argnums=(1, 2))
10+
def rotate_flms(
11+
flm: jnp.ndarray,
12+
L: int,
13+
rotation: Tuple[float, float, float],
14+
dl_array: jnp.ndarray = None,
15+
) -> jnp.ndarray:
16+
"""Rotates an array of spherical harmonic coefficients by angle rotation.
17+
18+
Args:
19+
flm (jnp.ndarray): Array of spherical harmonic coefficients.
20+
L (int): Harmonic band-limit.
21+
rotation (Tuple[float, float, float]): Rotation on the sphere (alpha, beta, gamma).
22+
dl_array (jnp.ndarray, optional): Precomputed array of reduced Wigner d-function
23+
coefficients, see :func:~`generate_rotate_dls`. Defaults to None.
24+
25+
Returns:
26+
jnp.ndarray: Rotated spherical harmonic coefficients with shape [L,2L-1].
27+
"""
28+
29+
# Split out angles
30+
alpha = __exp_array(L, rotation[0])
31+
gamma = __exp_array(L, rotation[2])
32+
beta = rotation[1]
33+
34+
# Create empty arrays
35+
flm_rotated = jnp.zeros_like(flm)
36+
37+
dl = (
38+
dl_array
39+
if dl_array != None
40+
else jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.complex128)
41+
)
42+
43+
# Perform rotation
44+
for el in range(L):
45+
if dl_array is None:
46+
dl = compute_full(dl, beta, L, el)
47+
n_max = min(el, L - 1)
48+
49+
m = jnp.arange(-el, el + 1)
50+
n = jnp.arange(-n_max, n_max + 1)
51+
52+
flm_rotated = flm_rotated.at[el, L - 1 + m].add(
53+
jnp.einsum(
54+
"mn,n->m",
55+
jnp.einsum(
56+
"mn,m->mn",
57+
dl[m + L - 1][:, n + L - 1]
58+
if dl_array is None
59+
else dl[el, m + L - 1][:, n + L - 1],
60+
alpha[m + L - 1],
61+
optimize=True,
62+
),
63+
gamma[n + L - 1] * flm[el, n + L - 1],
64+
)
65+
)
66+
return flm_rotated
67+
68+
69+
@partial(jit, static_argnums=(0, 1))
70+
def __exp_array(L: int, x: float) -> jnp.ndarray:
71+
"""Private function to generate rotation arrays for alpha/gamma rotations"""
72+
return jnp.exp(-1j * jnp.arange(-L + 1, L) * x)
73+
74+
75+
@partial(jit, static_argnums=(0, 1))
76+
def generate_rotate_dls(L: int, beta: float) -> jnp.ndarray:
77+
"""Function which recursively generates the complete plane of reduced
78+
Wigner d-function coefficients at a given rotation beta.
79+
80+
Args:
81+
L (int): Harmonic band-limit.
82+
beta (float): Rotation on the sphere.
83+
84+
Returns:
85+
jnp.ndarray: Complete array of [L, 2L-1,2L-1] Wigner d-function coefficients
86+
for a fixed rotation beta.
87+
"""
88+
dl = jnp.zeros((L, 2 * L - 1, 2 * L - 1)).astype(jnp.float64)
89+
dl_iter = jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.float64)
90+
for el in range(L):
91+
dl_iter = compute_full(dl_iter, beta, L, el)
92+
dl = dl.at[el].add(dl_iter)
93+
return dl

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pyssht as ssht
66
import numpy as np
77
from s2fft.sampling import s2_samples as samples
8-
from s2fft.utils.augmentation import rotate_flms, generate_rotate_dls
8+
from s2fft.utils.rotation import rotate_flms, generate_rotate_dls
99
import jax.numpy as jnp
1010
from jax.test_util import check_grads
1111

0 commit comments

Comments
 (0)