Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 70 additions & 24 deletions s2fft/precompute_transforms/fourier_wigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def inverse_transform(
flmn: np.ndarray,
delta: np.ndarray,
DW: np.ndarray,
L: int,
N: int,
reality: bool = False,
Expand All @@ -18,7 +18,7 @@ def inverse_transform(

Args:
flmn (np.ndarray): Wigner coefficients.
delta (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
Wigner d-functions and the corresponding upsampled quadrature weights.
L (int): Harmonic band-limit.
N (int): Azimuthal band-limit.
Expand All @@ -32,6 +32,14 @@ def inverse_transform(
np.ndarray: Pixel-space function sampled on the rotation group.

"""
if sampling.lower() not in ["mw", "mwss"]:
raise ValueError(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, _ = DW

# INDEX VALUES
n_start_ind = N - 1 if reality else 0
n_dim = N if reality else 2 * N - 1
Expand All @@ -44,13 +52,13 @@ def inverse_transform(
m = np.arange(-L + 1 - m_offset, L)
n = np.arange(n_start_ind - N + 1, N)

# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
x = np.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
x[m_offset:, m_offset:] = np.einsum(
"nlm,lam,lan,l->amn",
flmn[n_start_ind:],
delta[0],
delta[0][:, :, L - 1 + n],
Delta,
Delta[:, :, L - 1 + n],
(2 * np.arange(L) + 1) / (8 * np.pi**2),
)

Expand All @@ -72,7 +80,7 @@ def inverse_transform(
@partial(jit, static_argnums=(2, 3, 4, 5))
def inverse_transform_jax(
flmn: jnp.ndarray,
delta: jnp.ndarray,
DW: jnp.ndarray,
L: int,
N: int,
reality: bool = False,
Expand All @@ -83,7 +91,7 @@ def inverse_transform_jax(

Args:
flmn (jnp.ndarray): Wigner coefficients.
delta (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
Wigner d-functions and the corresponding upsampled quadrature weights.
L (int): Harmonic band-limit.
N (int): Azimuthal band-limit.
Expand All @@ -97,6 +105,14 @@ def inverse_transform_jax(
jnp.ndarray: Pixel-space function sampled on the rotation group.

"""
if sampling.lower() not in ["mw", "mwss"]:
raise ValueError(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, _ = DW

# INDEX VALUES
n_start_ind = N - 1 if reality else 0
n_dim = N if reality else 2 * N - 1
Expand All @@ -109,17 +125,15 @@ def inverse_transform_jax(
m = jnp.arange(-L + 1 - m_offset, L)
n = jnp.arange(n_start_ind - N + 1, N)

# Calculate fmna = i^(n-m)\sum_L delta^l_am delta^l_an f^l_mn(2l+1)/(8pi^2)
x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=flmn.dtype)
# Calculate fmna = i^(n-m)\sum_L Delta^l_am Delta^l_an f^l_mn(2l+1)/(8pi^2)
x = jnp.zeros((xnlm_size, xnlm_size, n_dim), dtype=jnp.complex128)
flmn = jnp.einsum("nlm,l->nlm", flmn, (2 * jnp.arange(L) + 1) / (8 * jnp.pi**2))
x = x.at[m_offset:, m_offset:].set(
jnp.einsum(
"nlm,lam,lan,l->amn",
flmn[n_start_ind:],
delta[0],
delta[0][:, :, L - 1 + n],
(2 * jnp.arange(L) + 1) / (8 * jnp.pi**2),
"nlm,lam,lan->amn", flmn[n_start_ind:], Delta, Delta[:, :, L - 1 + n]
)
)

# APPLY SIGN FUNCTION AND PHASE SHIFT
x = jnp.einsum("amn,m,n,a->nam", x, 1j ** (-m), 1j ** (n), jnp.exp(1j * m * theta0))

Expand All @@ -136,14 +150,19 @@ def inverse_transform_jax(


def forward_transform(
f: np.ndarray, delta: np.ndarray, L: int, N: int, reality: bool, sampling: str
f: np.ndarray,
DW: np.ndarray,
L: int,
N: int,
reality: bool = False,
sampling: str = "mw",
) -> np.ndarray:
"""
Computes the forward Wigner transform using the Fourier decomposition algorithm.

Args:
f (np.ndarray): Function sampled on the rotation group.
delta (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
Wigner d-functions and the corresponding upsampled quadrature weights.
L (int): Harmonic band-limit.
N (int): Azimuthal band-limit.
Expand All @@ -157,6 +176,14 @@ def forward_transform(
np.ndarray: Wigner coefficients of function f.

"""
if sampling.lower() not in ["mw", "mwss"]:
raise ValueError(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, Quads = DW

# INDEX VALUES
n_start_ind = N - 1 if reality else 0
m_offset = 1 if sampling.lower() == "mwss" else 0
Expand Down Expand Up @@ -193,14 +220,17 @@ def forward_transform(
x = np.fft.ifft(x, axis=1, norm="forward")

# PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE
x = np.einsum("nbm,b->nbm", x, delta[1])
# NB: Our convention here is conjugate to that of SSHT, in which
# the weights are conjugate but applied flipped and therefore are
# equivalent. To avoid flipping here he simply conjugate the weights.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: he -> we

x = np.einsum("nbm,b->nbm", x, Quads)

# COMPUTE GMM BY FFT
x = np.fft.fft(x, axis=1, norm="forward")
x = np.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]

# Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
x = np.einsum("nam,lam,lan->nlm", x, delta[0], delta[0][:, :, L - 1 + n])
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
x = np.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
x = np.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))

# SYMMETRY REFLECT FOR N < 0
Expand All @@ -218,14 +248,19 @@ def forward_transform(

@partial(jit, static_argnums=(2, 3, 4, 5))
def forward_transform_jax(
f: jnp.ndarray, delta: jnp.ndarray, L: int, N: int, reality: bool, sampling: str
f: jnp.ndarray,
DW: jnp.ndarray,
L: int,
N: int,
reality: bool = False,
sampling: str = "mw",
) -> jnp.ndarray:
"""
Computes the forward Wigner transform using the Fourier decomposition algorithm (JAX).

Args:
f (jnp.ndarray): Function sampled on the rotation group.
delta (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
Wigner d-functions and the corresponding upsampled quadrature weights.
L (int): Harmonic band-limit.
N (int): Azimuthal band-limit.
Expand All @@ -239,6 +274,14 @@ def forward_transform_jax(
jnp.ndarray: Wigner coefficients of function f.

"""
if sampling.lower() not in ["mw", "mwss"]:
raise ValueError(
f"Fourier-Wigner algorithm does not support {sampling} sampling."
)

# EXTRACT VARIOUS PRECOMPUTES
Delta, Quads = DW

# INDEX VALUES
n_start_ind = N - 1 if reality else 0
m_offset = 1 if sampling.lower() == "mwss" else 0
Expand Down Expand Up @@ -275,14 +318,17 @@ def forward_transform_jax(
x = jnp.fft.ifft(x, axis=1, norm="forward")

# PERFORM QUADRATURE CONVOLUTION AS FFT REWEIGHTING IN REAL SPACE
x = jnp.einsum("nbm,b->nbm", x, delta[1])
# NB: Our convention here is conjugate to that of SSHT, in which
# the weights are conjugate but applied flipped and therefore are
# equivalent. To avoid flipping here he simply conjugate the weights.
x = jnp.einsum("nbm,b->nbm", x, Quads)

# COMPUTE GMM BY FFT
x = jnp.fft.fft(x, axis=1, norm="forward")
x = jnp.fft.fftshift(x, axes=1)[:, L - 1 : 3 * L - 2]

# Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
x = jnp.einsum("nam,lam,lan->nlm", x, delta[0], delta[0][:, :, L - 1 + n])
# Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
x = jnp.einsum("nam,lam,lan->nlm", x, Delta, Delta[:, :, L - 1 + n])
x = jnp.einsum("nbm,m,n->nbm", x, 1j ** (m), 1j ** (-n))

# SYMMETRY REFLECT FOR N < 0
Expand Down
Loading