Skip to content

Commit 64e8b4d

Browse files
committed
fix bugs from issue 183
1 parent 5ccf24b commit 64e8b4d

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

s2fft/recursions/risbo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
55
r"""Compute Wigner-d at argument :math:`\beta` for full plane using
66
Risbo recursion.
77
8-
The Wigner-d plane is computed by recursion over :math:`\ell` (`el`).
8+
The Wigner-d plane is computed by recursion over :math:`\ell`.
99
Thus, for :math:`\ell > 0` the plane must be computed already for
1010
:math:`\ell - 1`. At present, for :math:`\ell = 0` the recusion is initialised.
1111
@@ -19,7 +19,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
1919
el (int): Spherical harmonic degree :math:`\ell`.
2020
2121
Returns:
22-
np.ndarray: Plane of Wigner-d for `el` and `beta`, with full plane computed.
22+
np.ndarray: Plane of Wigner-d for :math:`\ell` and :math:`\beta`, with full plane computed.
2323
"""
2424

2525
_arg_checks(dl, beta, L, el)
@@ -103,7 +103,7 @@ def compute_full(dl: np.ndarray, beta: float, L: int, el: int) -> np.ndarray:
103103

104104

105105
def _arg_checks(dl: np.ndarray, beta: float, L: int, el: int):
106-
"""Check arguments of Risbo functions.
106+
r"""Check arguments of Risbo functions.
107107
108108
Args:
109109
dl (np.ndarray): Wigner-d plane of which to check shape.

s2fft/transforms/spherical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def f_bwd(res, gtm):
296296
ftm = ftm.at[:, m_start_ind + m_offset :].multiply(phase_shifts)
297297

298298
# Perform longitundal Fast Fourier Transforms
299-
ftm *= (-1) ** spin
299+
ftm *= (-1) ** jnp.abs(spin)
300300
if reality:
301301
ftm = ftm.at[:, m_offset : L - 1 + m_offset].set(
302302
jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1)
@@ -657,4 +657,4 @@ def f_bwd(res, glm):
657657
indices = jnp.repeat(jnp.expand_dims(jnp.arange(L), -1), 2 * L - 1, axis=-1)
658658
flm = jnp.where(indices < abs(spin), jnp.zeros_like(flm), flm[..., :])
659659

660-
return flm * (-1) ** spin
660+
return flm * (-1) ** jnp.abs(spin)

s2fft/transforms/wigner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def inverse_jax(
259259
def spherical_loop(n, args):
260260
fban, flmn, lrenorm, vsign, spins = args
261261
fban = fban.at[n].add(
262-
(-1) ** spins[n]
262+
(-1) ** jnp.abs(spins[n])
263263
* s2fft.inverse_jax(
264264
flmn[n],
265265
L,

s2fft/utils/rotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def rotate_flms(
3737
dl = (
3838
dl_array
3939
if dl_array != None
40-
else jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.complex128)
40+
else jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.float64)
4141
)
4242

4343
# Perform rotation

0 commit comments

Comments
 (0)