Skip to content

Commit c2bf786

Browse files
committed
explicitly handle singularities in risbo JAX recursion
1 parent 81891d1 commit c2bf786

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

s2fft/recursions/risbo_jax.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,24 @@ def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray:
2424
Returns:
2525
np.ndarray: Plane of Wigner-d for `el` and `beta`, with full plane computed.
2626
"""
27+
28+
if beta == 0:
29+
dl = dl.at[-el + L - 1, el + L - 1].set(1.0)
30+
2731
if el == 0:
2832
dl = dl.at[el + L - 1, el + L - 1].set(1.0)
2933
return dl
30-
elif el == 1:
34+
35+
if beta == jnp.pi:
36+
dl = dl.at[:, :].multiply(0)
37+
ind = jnp.arange(-el + L - 1, el + L)
38+
dl = jnp.flip(dl.at[ind, ind].add(1.0), axis=0)
39+
dl = jnp.einsum(
40+
"nm,m->nm", dl, (-1) ** (el + jnp.arange(-L + 1, L)), optimize=True
41+
)
42+
return dl
43+
44+
if el == 1:
3145
cosb = jnp.cos(beta)
3246
sinb = jnp.sin(beta)
3347

tests/test_wigner_recursions.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -136,16 +136,6 @@ def test_trapani_interfaces():
136136
recursions.trapani.compute_full(dl_jax, L, el, implementation="unexpected")
137137

138138

139-
def test_trapani_checks():
140-
# TODO
141-
142-
# Check throws exception if arguments wrong
143-
144-
# Check throws exception if don't init
145-
146-
return
147-
148-
149139
def test_risbo_with_ssht():
150140
"""Test Risbo computation against ssht"""
151141

@@ -171,15 +161,16 @@ def test_risbo_with_ssht_jax():
171161
L = 10
172162

173163
# Compute using SSHT.
174-
beta = np.pi / 2.0
175-
dl_array = ssht.generate_dl(beta, L)
164+
betas = [0, np.pi / 2.0, np.pi]
165+
for beta in betas:
166+
dl_array = ssht.generate_dl(beta, L)
176167

177-
# Compare to routines in SSHT, which have been validated extensively.
178-
dl = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
168+
# Compare to routines in SSHT, which have been validated extensively.
169+
dl = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
179170

180-
for el in range(0, L):
181-
dl = recursions.risbo_jax.compute_full(dl, beta, L, el)
182-
np.testing.assert_allclose(dl_array[el, :, :], dl, atol=1e-15)
171+
for el in range(0, L):
172+
dl = recursions.risbo_jax.compute_full(dl, beta, L, el)
173+
np.testing.assert_allclose(dl_array[el, :, :], dl, atol=1e-15)
183174

184175

185176
@pytest.mark.parametrize("L", L_to_test)

0 commit comments

Comments
 (0)