Skip to content

Commit f545458

Browse files
committed
add vectorized jax version of risbo wigner-d recursion
1 parent 2d2370f commit f545458

File tree

6 files changed

+126
-3
lines changed

6 files changed

+126
-3
lines changed

docs/api/recursions/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ Wigner-d recursions
8989
- Description
9090
* - :func:`~s2fft.recursions.risbo.compute_full`
9191
- Compute Wigner-d at argument :math:`\beta` for full plane using Risbo recursion.
92+
* - :func:`~s2fft.recursions.risbo_jax.compute_full`
93+
- Compute Wigner-d at argument :math:`\beta` for full plane using Risbo recursion (JAX implementation).
9294

9395
.. warning::
9496

docs/api/recursions/risbo_jax.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
:html_theme.sidebar_secondary.remove:
2+
3+
**************************
4+
Risbo JAX
5+
**************************
6+
.. automodule:: s2fft.recursions.risbo_jax
7+
:members:

s2fft/recursions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from . import trapani
22
from . import risbo
3+
from . import risbo_jax
34
from . import turok
45
from . import turok_jax
56
from . import price_mcewen

s2fft/recursions/risbo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
5656
# from l - 1 to l - 1/2.
5757
dd = np.zeros((2 * el + 2, 2 * el + 2))
5858
j = 2 * el - 1
59-
rj = float(j) # TODO: is this necessary?
59+
6060
for k in range(0, j):
6161
sqrt_jmk = np.sqrt(j - k)
6262
sqrt_kp1 = np.sqrt(k + 1)
@@ -77,7 +77,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
7777
# the plane of the dl-matrix to 0.0.
7878
dl[-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1] = 0.0
7979
j = 2 * el
80-
rj = float(j) # TODO: is this necessary?
80+
8181
for k in range(0, j):
8282
sqrt_jmk = np.sqrt(j - k)
8383
sqrt_kp1 = np.sqrt(k + 1)

s2fft/recursions/risbo_jax.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import jax.numpy as jnp
2+
from jax import jit
3+
from functools import partial
4+
5+
6+
@partial(jit, static_argnums=(1, 2, 3))
7+
def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
8+
if el == 0:
9+
dl = dl.at[el + L - 1, el + L - 1].set(1.0)
10+
return dl
11+
elif el == 1:
12+
cosb = jnp.cos(beta)
13+
sinb = jnp.sin(beta)
14+
15+
coshb = jnp.cos(beta / 2.0)
16+
sinhb = jnp.sin(beta / 2.0)
17+
sqrt2 = jnp.sqrt(2.0)
18+
19+
dl = dl.at[L - 2, L - 2].set(coshb**2)
20+
dl = dl.at[L - 2, L - 1].set(sinb / sqrt2)
21+
dl = dl.at[L - 2, L].set(sinhb**2)
22+
23+
dl = dl.at[L - 1, L - 2].set(-sinb / sqrt2)
24+
dl = dl.at[L - 1, L - 1].set(cosb)
25+
dl = dl.at[L - 1, L].set(sinb / sqrt2)
26+
27+
dl = dl.at[L, L - 2].set(sinhb**2)
28+
dl = dl.at[L, L - 1].set(-sinb / sqrt2)
29+
dl = dl.at[L, L].set(coshb**2)
30+
return dl
31+
else:
32+
coshb = -jnp.cos(beta / 2.0)
33+
sinhb = jnp.sin(beta / 2.0)
34+
dd = jnp.zeros((2 * el + 2, 2 * el + 2))
35+
36+
# First pass
37+
j = 2 * el - 1
38+
i = jnp.arange(j)
39+
k = jnp.arange(j)
40+
41+
sqrt_jmk = jnp.sqrt(j - k)
42+
sqrt_kp1 = jnp.sqrt(k + 1)
43+
sqrt_jmi = jnp.sqrt(j - i)
44+
sqrt_ip1 = jnp.sqrt(i + 1)
45+
46+
dlj = dl[k - (el - 1) + L - 1][:, i - (el - 1) + L - 1]
47+
48+
dd = dd.at[:j, :j].add(
49+
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True) * dlj * coshb
50+
)
51+
dd = dd.at[:j, 1 : j + 1].add(
52+
jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True) * dlj * sinhb
53+
)
54+
dd = dd.at[1 : j + 1, :j].add(
55+
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True) * dlj * sinhb
56+
)
57+
dd = dd.at[1 : j + 1, 1 : j + 1].add(
58+
jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True) * dlj * coshb
59+
)
60+
61+
dl = dl.at[-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1].multiply(
62+
0.0
63+
)
64+
65+
j = 2 * el
66+
i = jnp.arange(j)
67+
k = jnp.arange(j)
68+
69+
# Second pass
70+
sqrt_jmk = jnp.sqrt(j - k)
71+
sqrt_kp1 = jnp.sqrt(k + 1)
72+
sqrt_jmi = jnp.sqrt(j - i)
73+
sqrt_ip1 = jnp.sqrt(i + 1)
74+
75+
dl = dl.at[-el + L - 1 : el + L - 1, -el + L - 1 : el + L - 1].add(
76+
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True)
77+
* dd[:j, :j]
78+
* coshb,
79+
)
80+
dl = dl.at[-el + L - 1 : el + L - 1, L - el : L + el].add(
81+
jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True)
82+
* dd[:j, :j]
83+
* sinhb,
84+
)
85+
dl = dl.at[L - el : L + el, -el + L - 1 : el + L - 1].add(
86+
jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True)
87+
* dd[:j, :j]
88+
* sinhb,
89+
)
90+
dl = dl.at[L - el : L + el, L - el : L + el].add(
91+
jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True)
92+
* dd[:j, :j]
93+
* coshb,
94+
)
95+
return dl / ((2 * el) * (2 * el - 1))

tests/test_wigner_recursions.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,30 @@ def test_risbo_with_ssht():
158158

159159
# Compare to routines in SSHT, which have been validated extensively.
160160
dl = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
161-
# dl = recursions.trapani.init(dl, L)
161+
162162
for el in range(0, L):
163163
dl = recursions.risbo.compute_full(dl, beta, L, el)
164164
np.testing.assert_allclose(dl_array[el, :, :], dl, atol=1e-15)
165165

166166

167+
def test_risbo_with_ssht_jax():
168+
"""Test Risbo JAX computation against ssht"""
169+
170+
# Test all dl(pi/2) terms up to L.
171+
L = 10
172+
173+
# Compute using SSHT.
174+
beta = np.pi / 2.0
175+
dl_array = ssht.generate_dl(beta, L)
176+
177+
# Compare to routines in SSHT, which have been validated extensively.
178+
dl = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
179+
180+
for el in range(0, L):
181+
dl = recursions.risbo_jax.compute_full(dl, beta, L, el)
182+
np.testing.assert_allclose(dl_array[el, :, :], dl, atol=1e-15)
183+
184+
167185
@pytest.mark.parametrize("L", L_to_test)
168186
@pytest.mark.parametrize("sampling", ["mw", "mwss", "dh", "healpix"])
169187
def test_turok_with_ssht(L: int, sampling: str):

0 commit comments

Comments
 (0)