Skip to content

Commit ec20bff

Browse files
committed
add rotation function (risbo) and update docs
1 parent f545458 commit ec20bff

File tree

6 files changed

+181
-1
lines changed

6 files changed

+181
-1
lines changed

docs/api/utility/augmentation.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
augmentation
5+
**************************
6+
.. automodule:: s2fft.utils.augmentation
7+
:members:

docs/api/utility/index.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,22 @@ Utility Functions
100100
* - :func:`~s2fft.utils.signal_generator.generate_flmn`
101101
- Generate a 3D set of random Wigner coefficients.
102102

103-
104103
.. note::
105104

106105
JAX versions of these functions share an almost identical function trace and
107106
are simply accessed by the sub-module :func:`~s2fft.utils.resampling_jax`.
108107

108+
.. list-table:: Augmentation functions
109+
:widths: 25 25
110+
:header-rows: 1
111+
112+
* - Function Name
113+
- Description
114+
* - :func:`~s2fft.utils.augmentation.rotate_flms`
115+
- Euler rotates spherical harmonic coefficients by given angle in zyz convention.
116+
* - :func:`~s2fft.utils.augmentation.generate_rotate_dls`
117+
- Generates an array of all reduced Wigner d-function coefficients for angle beta.
118+
109119
.. toctree::
110120
:hidden:
111121
:maxdepth: 2
@@ -118,5 +128,6 @@ Utility Functions
118128
quadrature_jax
119129
healpix_ffts
120130
utils
131+
augmentation
121132
logs
122133

s2fft/recursions/risbo_jax.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,25 @@
55

66
@partial(jit, static_argnums=(1, 2, 3))
77
def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
8+
r"""Compute Wigner-d at argument :math:`\beta` for full plane using
9+
Risbo recursion (JAX implementation)
10+
11+
The Wigner-d plane is computed by recursion over :math:`\ell` (`el`).
12+
Thus, for :math:`\ell > 0` the plane must be computed already for
13+
:math:`\ell - 1`. At present, for :math:`\ell = 0` the recusion is initialised.
14+
15+
Args:
16+
dl (np.ndarray): Wigner-d plane for :math:`\ell - 1` at :math:`\beta`.
17+
18+
beta (float): Argument :math:`\beta` at which to compute Wigner-d plane.
19+
20+
L (int): Harmonic band-limit.
21+
22+
el (int): Spherical harmonic degree :math:`\ell`.
23+
24+
Returns:
25+
np.ndarray: Plane of Wigner-d for `el` and `beta`, with full plane computed.
26+
"""
827
if el == 0:
928
dl = dl.at[el + L - 1, el + L - 1].set(1.0)
1029
return dl

s2fft/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from . import resampling_jax
55
from . import healpix_ffts
66
from . import signal_generator
7+
from . import augmentation

s2fft/utils/augmentation.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import jax.numpy as jnp
2+
from jax import jit
3+
from functools import partial
4+
from typing import Tuple
5+
6+
from s2fft.recursions.risbo_jax import compute_full
7+
8+
9+
@partial(jit, static_argnums=(1, 2))
10+
def rotate_flms(
11+
flm: jnp.ndarray,
12+
L: int,
13+
rotation: Tuple[float, float, float],
14+
dl_array: jnp.ndarray = None,
15+
) -> jnp.ndarray:
16+
"""Rotates an array of spherical harmonic coefficients by angle rotation.
17+
18+
Args:
19+
flm (jnp.ndarray): Array of spherical harmonic coefficients.
20+
L (int): Harmonic band-limit.
21+
rotation (Tuple[float, float, float]): Rotation on the sphere (alpha, beta, gamma).
22+
dl_array (jnp.ndarray, optional): Precomputed array of reduced Wigner d-function
23+
coefficients, see :func:~`generate_rotate_dls`. Defaults to None.
24+
25+
Returns:
26+
jnp.ndarray: Rotated spherical harmonic coefficients with shape [L,2L-1].
27+
"""
28+
29+
# Split out angles
30+
alpha = __exp_array(L, rotation[0])
31+
gamma = __exp_array(L, rotation[2])
32+
beta = rotation[1]
33+
34+
# Create empty arrays
35+
flm_rotated = jnp.zeros_like(flm)
36+
37+
dl = (
38+
dl_array
39+
if dl_array != None
40+
else jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.complex128)
41+
)
42+
43+
# Perform rotation
44+
for el in range(L):
45+
if dl_array is None:
46+
dl = compute_full(dl, beta, L, el)
47+
n_max = min(el, L - 1)
48+
49+
m = jnp.arange(-el, el + 1)
50+
n = jnp.arange(-n_max, n_max + 1)
51+
52+
flm_rotated = flm_rotated.at[el, L - 1 + m].add(
53+
jnp.einsum(
54+
"mn,n->m",
55+
jnp.einsum(
56+
"mn,m->mn",
57+
dl[m + L - 1][:, n + L - 1]
58+
if dl_array is None
59+
else dl[el, m + L - 1][:, n + L - 1],
60+
alpha[m + L - 1],
61+
optimize=True,
62+
),
63+
gamma[n + L - 1] * flm[el, n + L - 1],
64+
)
65+
)
66+
return flm_rotated
67+
68+
69+
@partial(jit, static_argnums=(0, 1))
70+
def __exp_array(L: int, x: float) -> jnp.ndarray:
71+
"""Private function to generate rotation arrays for alpha/gamma rotations"""
72+
return jnp.exp(-1j * jnp.arange(-L + 1, L) * x)
73+
74+
75+
@partial(jit, static_argnums=(0, 1))
76+
def generate_rotate_dls(L: int, beta: float) -> jnp.ndarray:
77+
"""Function which recursively generates the complete plane of reduced
78+
Wigner d-function coefficients at a given rotation beta.
79+
80+
Args:
81+
L (int): Harmonic band-limit.
82+
beta (float): Rotation on the sphere.
83+
84+
Returns:
85+
jnp.ndarray: Complete array of [L, 2L-1,2L-1] Wigner d-function coefficients
86+
for a fixed rotation beta.
87+
"""
88+
dl = jnp.zeros((L, 2 * L - 1, 2 * L - 1)).astype(jnp.float64)
89+
dl_iter = jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.float64)
90+
for el in range(L):
91+
dl_iter = compute_full(dl_iter, beta, L, el)
92+
dl = dl.at[el].add(dl_iter)
93+
return dl

tests/test_utils.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1+
from jax.config import config
2+
3+
config.update("jax_enable_x64", True)
14
import pytest
5+
import pyssht as ssht
26
import numpy as np
37
from s2fft.sampling import s2_samples as samples
8+
from s2fft.utils.augmentation import rotate_flms, generate_rotate_dls
9+
10+
L_to_test = [6, 8, 10]
11+
angles_to_test = [np.pi / 2, np.pi / 6]
412

513

614
def test_flm_reindexing_functions(flm_generator):
@@ -47,3 +55,44 @@ def test_flm_reindexing_exceptions(flm_generator):
4755

4856
with pytest.raises(ValueError) as e:
4957
samples.flm_1d_to_2d(flm_3d, L)
58+
59+
60+
@pytest.mark.parametrize("L", L_to_test)
61+
@pytest.mark.parametrize("alpha", angles_to_test)
62+
@pytest.mark.parametrize("beta", angles_to_test)
63+
@pytest.mark.parametrize("gamma", angles_to_test)
64+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
65+
@pytest.mark.filterwarnings("ignore::FutureWarning")
66+
def test_rotate_flms(flm_generator, L: int, alpha: float, beta: float, gamma: float):
67+
flm = flm_generator(L=L)
68+
rot = (alpha, beta, gamma)
69+
flm_1d = samples.flm_2d_to_1d(flm, L)
70+
71+
flm_rot_ssht = samples.flm_1d_to_2d(
72+
ssht.rotate_flms(flm_1d, alpha, beta, gamma, L), L
73+
)
74+
flm_rot_s2fft = rotate_flms(flm, L, rot)
75+
76+
np.testing.assert_allclose(flm_rot_ssht, flm_rot_s2fft, atol=1e-14)
77+
78+
79+
@pytest.mark.parametrize("L", L_to_test)
80+
@pytest.mark.parametrize("alpha", angles_to_test)
81+
@pytest.mark.parametrize("beta", angles_to_test)
82+
@pytest.mark.parametrize("gamma", angles_to_test)
83+
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
84+
@pytest.mark.filterwarnings("ignore::FutureWarning")
85+
def test_rotate_flms_precompute_dls(
86+
flm_generator, L: int, alpha: float, beta: float, gamma: float
87+
):
88+
dl = generate_rotate_dls(L, beta)
89+
flm = flm_generator(L=L)
90+
rot = (alpha, beta, gamma)
91+
flm_1d = samples.flm_2d_to_1d(flm, L)
92+
93+
flm_rot_ssht = samples.flm_1d_to_2d(
94+
ssht.rotate_flms(flm_1d, alpha, beta, gamma, L), L
95+
)
96+
flm_rot_s2fft = rotate_flms(flm, L, rot, dl)
97+
98+
np.testing.assert_allclose(flm_rot_ssht, flm_rot_s2fft, atol=1e-14)

0 commit comments

Comments
 (0)