|
| 1 | +import jax.numpy as jnp |
| 2 | +from jax import jit |
| 3 | +from functools import partial |
| 4 | +from typing import Tuple |
| 5 | + |
| 6 | +from s2fft.recursions.risbo_jax import compute_full |
| 7 | + |
| 8 | + |
| 9 | +@partial(jit, static_argnums=(1, 2)) |
| 10 | +def rotate_flms( |
| 11 | + flm: jnp.ndarray, |
| 12 | + L: int, |
| 13 | + rotation: Tuple[float, float, float], |
| 14 | + dl_array: jnp.ndarray = None, |
| 15 | +) -> jnp.ndarray: |
| 16 | + """Rotates an array of spherical harmonic coefficients by angle rotation. |
| 17 | +
|
| 18 | + Args: |
| 19 | + flm (jnp.ndarray): Array of spherical harmonic coefficients. |
| 20 | + L (int): Harmonic band-limit. |
| 21 | + rotation (Tuple[float, float, float]): Rotation on the sphere (alpha, beta, gamma). |
| 22 | + dl_array (jnp.ndarray, optional): Precomputed array of reduced Wigner d-function |
| 23 | + coefficients, see :func:~`generate_rotate_dls`. Defaults to None. |
| 24 | +
|
| 25 | + Returns: |
| 26 | + jnp.ndarray: Rotated spherical harmonic coefficients with shape [L,2L-1]. |
| 27 | + """ |
| 28 | + |
| 29 | + # Split out angles |
| 30 | + alpha = __exp_array(L, rotation[0]) |
| 31 | + gamma = __exp_array(L, rotation[2]) |
| 32 | + beta = rotation[1] |
| 33 | + |
| 34 | + # Create empty arrays |
| 35 | + flm_rotated = jnp.zeros_like(flm) |
| 36 | + |
| 37 | + dl = ( |
| 38 | + dl_array |
| 39 | + if dl_array != None |
| 40 | + else jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.complex128) |
| 41 | + ) |
| 42 | + |
| 43 | + # Perform rotation |
| 44 | + for el in range(L): |
| 45 | + if dl_array is None: |
| 46 | + dl = compute_full(dl, beta, L, el) |
| 47 | + n_max = min(el, L - 1) |
| 48 | + |
| 49 | + m = jnp.arange(-el, el + 1) |
| 50 | + n = jnp.arange(-n_max, n_max + 1) |
| 51 | + |
| 52 | + flm_rotated = flm_rotated.at[el, L - 1 + m].add( |
| 53 | + jnp.einsum( |
| 54 | + "mn,n->m", |
| 55 | + jnp.einsum( |
| 56 | + "mn,m->mn", |
| 57 | + dl[m + L - 1][:, n + L - 1] |
| 58 | + if dl_array is None |
| 59 | + else dl[el, m + L - 1][:, n + L - 1], |
| 60 | + alpha[m + L - 1], |
| 61 | + optimize=True, |
| 62 | + ), |
| 63 | + gamma[n + L - 1] * flm[el, n + L - 1], |
| 64 | + ) |
| 65 | + ) |
| 66 | + return flm_rotated |
| 67 | + |
| 68 | + |
| 69 | +@partial(jit, static_argnums=(0, 1)) |
| 70 | +def __exp_array(L: int, x: float) -> jnp.ndarray: |
| 71 | + """Private function to generate rotation arrays for alpha/gamma rotations""" |
| 72 | + return jnp.exp(-1j * jnp.arange(-L + 1, L) * x) |
| 73 | + |
| 74 | + |
| 75 | +@partial(jit, static_argnums=(0, 1)) |
| 76 | +def generate_rotate_dls(L: int, beta: float) -> jnp.ndarray: |
| 77 | + """Function which recursively generates the complete plane of reduced |
| 78 | + Wigner d-function coefficients at a given rotation beta. |
| 79 | +
|
| 80 | + Args: |
| 81 | + L (int): Harmonic band-limit. |
| 82 | + beta (float): Rotation on the sphere. |
| 83 | +
|
| 84 | + Returns: |
| 85 | + jnp.ndarray: Complete array of [L, 2L-1,2L-1] Wigner d-function coefficients |
| 86 | + for a fixed rotation beta. |
| 87 | + """ |
| 88 | + dl = jnp.zeros((L, 2 * L - 1, 2 * L - 1)).astype(jnp.float64) |
| 89 | + dl_iter = jnp.zeros((2 * L - 1, 2 * L - 1)).astype(jnp.float64) |
| 90 | + for el in range(L): |
| 91 | + dl_iter = compute_full(dl_iter, beta, L, el) |
| 92 | + dl = dl.at[el].add(dl_iter) |
| 93 | + return dl |
0 commit comments