-
Notifications
You must be signed in to change notification settings - Fork 14
Closed
Description
Hi! First of all, thanks a lot for developing this package! I don't work in astrophysics, but it has been exactly what I needed for a project I'm working on.
The issue I'm having concerns the JAX implementation of the inverse Wigner transform, `s2fft.wigner.inverse_jax`. Luckily, I stumbled upon the SSHT version of this function (`s2fft.wigner.inverse_jax_ssht`), which let me implement what I needed, albeit only on the CPU.
Below is a minimal (non-)working example that illustrates my use case.
Apologies that it is still somewhat verbose.
I haven't dug into the library code to debug this, but judging by the magnitude of the output values, it looks like catastrophic floating-point errors are occurring (well beyond the aliasing artefacts I originally suspected).
"""Minimum example showing `s2fft`'s inverse Wigner transformation not working reliably."""
from collections.abc import Callable
from typing import TypeAlias
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.scipy.spatial import transform
# Allowing 64-bit float before importing s2fft
jax.config.update("jax_enable_x64", True)
import s2fft # noqa: E402 pylint: disable=C0413
from s2fft import partial, sampling # noqa: E402 pylint: disable=C0413
# Signature of the test functions below: an array of Cartesian coordinates
# (x, y, z in the last axis) in, an array of scalar values out.
ScalarFunction: TypeAlias = Callable[[jax.Array], jax.Array]
# Sampling scheme passed to every s2fft sampling/transform call in this script.
S2FFT_ANGLE_SAMPLER: str = "dh"
# Passed as `reality=` to the s2fft transforms; presumably False means the
# signals are not assumed real-valued (TODO confirm against s2fft docs).
S2FFT_REALITY: bool = False
def dxy_orbital(coords: jax.Array, scale: float = 8.0) -> jax.Array:
    """(Arbitrary) function that looks like a 3d_xy orbital.

    Args:
        coords: Cartesian coordinates with ``x``, ``y``, ``z`` in the last axis.
        scale: Factor applied to the coordinates before evaluation; the default
            keeps the shape mostly contained in the plotting frame.

    Returns:
        ``exp(-|c| / 2) * c_x * c_y`` for ``c = scale * coords``, i.e. an array
        with the last (coordinate) axis reduced away.
    """
    # Scale out of place: `coords *= scale` would silently mutate a NumPy
    # input array owned by the caller (and fail outright for integer dtypes).
    scaled = jnp.asarray(coords) * scale
    x_s = scaled[..., 0]
    y_s = scaled[..., 1]
    # Euclidean norm over the coordinate axis.
    r_s = jnp.linalg.norm(scaled, axis=-1)
    return jnp.exp(-r_s / 2) * x_s * y_s
def _generate_samples(func: ScalarFunction, thetas: jax.Array, phis: jax.Array) -> jax.Array:
    """Evaluate ``func`` on the unit-sphere grid spanned by ``thetas`` x ``phis``.

    The grid points are converted to Cartesian coordinates (stacked in the last
    axis) before ``func`` is applied, so the leading axes of the result
    correspond to ``(theta, phi)``.
    """
    # "ij" indexing keeps theta on axis 0 and phi on axis 1.
    theta_grid, phi_grid = jnp.meshgrid(thetas, phis, indexing="ij")
    sin_theta = jnp.sin(theta_grid)
    cartesian = jnp.stack(
        [
            sin_theta * jnp.cos(phi_grid),
            sin_theta * jnp.sin(phi_grid),
            jnp.cos(theta_grid),
        ],
        axis=-1,
    )
    return func(cartesian)
def get_correlation(
    func: ScalarFunction, rotation: jax.Array, lmax: int, *, use_ssht: bool
) -> jax.Array:
    """Correlate ``func`` with a rotated copy of itself over SO(3).

    Used to show that the ``jax`` wrapper for the inverse Wigner transform is
    not behaving reliably: the same coefficient array is passed either to
    ``s2fft.wigner.inverse_jax_ssht`` or to ``s2fft.wigner.inverse_jax``.

    Args:
        func: Scalar function of Cartesian coordinates (``x, y, z`` in the
            last axis).
        rotation: 3x3 rotation matrix; the rotated function is evaluated as
            ``func(rotation.T @ coords)``.
        lmax: Band-limit passed to the forward spherical transforms and passed
            twice (as both band-limit arguments) to the inverse Wigner transform.
        use_ssht: If true use ``inverse_jax_ssht``, otherwise ``inverse_jax``.

    Returns:
        The real part of the inverse Wigner transform of the coefficient outer
        product.

    Raises:
        TypeError: If the coefficient outer product is not ``complex128``
            (e.g. ``jax_enable_x64`` was not enabled before importing s2fft).
    """

    def func_rotated(coords: jax.Array) -> jax.Array:
        # A plain closure rather than `s2fft.partial`, which is only an
        # incidental re-export of `functools.partial` and not s2fft API.
        rotated_coords = jnp.einsum("ij,...j->...i", rotation.T, coords)
        return func(rotated_coords)

    thetas = jnp.array(sampling.s2_samples.thetas(lmax, sampling=S2FFT_ANGLE_SAMPLER))
    phis = jnp.array(sampling.s2_samples.phis_equiang(lmax, sampling=S2FFT_ANGLE_SAMPLER))

    def _get_coeffs(scalar_func: ScalarFunction) -> jax.Array:
        """Sample ``scalar_func`` on the grid and return its forward transform."""
        sampled = _generate_samples(scalar_func, thetas, phis)
        return s2fft.forward_jax(
            sampled, lmax, reality=S2FFT_REALITY, sampling=S2FFT_ANGLE_SAMPLER
        )

    func_coeffs = _get_coeffs(func)
    func_rotated_coeffs = _get_coeffs(func_rotated)

    # result[n, l, m] = conj(g[l, n]) * f[l, m]. Assumes the coefficients are
    # 2-D (l, m) arrays — NOTE(review): confirm this matches the flmn layout
    # s2fft's Wigner inverse expects.
    outer_product = (
        jnp.matrix_transpose(jnp.conjugate(func_rotated_coeffs))[:, :, None]
        * func_coeffs[None, :, :]
    )
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if outer_product.dtype != jnp.complex128:
        raise TypeError(
            f"expected complex128 coefficients, got {outer_product.dtype}; "
            "enable jax_enable_x64 before importing s2fft"
        )

    # This is the part that varies between the two runs.
    inverse = s2fft.wigner.inverse_jax_ssht if use_ssht else s2fft.wigner.inverse_jax
    so3_correlation = inverse(
        outer_product, lmax, lmax, reality=S2FFT_REALITY, sampling=S2FFT_ANGLE_SAMPLER
    )
    return jnp.real(so3_correlation)
def plot_strange_results(output_path: str = "strange_s2fft_inverse_wigner.png") -> None:
    """Plot 1-D slices of the SSHT and JAX correlations through the SSHT peak.

    Row 0 shows the ``inverse_jax_ssht`` result and row 1 the ``inverse_jax``
    result. Each column varies one Euler angle (alpha, beta, gamma) while the
    other two are held at the argmax of the SSHT correlation.

    Args:
        output_path: File the figure is saved to.
    """
    print(f"{jax.config.read('jax_enable_x64')=}")

    # Correlate the orbital with a copy rotated by pi/2 about the z-axis.
    rotation = transform.Rotation.from_rotvec(jnp.array([0, 0, jnp.pi / 2])).as_matrix()
    lmax = 64
    ssht_score = get_correlation(dxy_orbital, rotation, lmax, use_ssht=True)
    jax_score = get_correlation(dxy_orbital, rotation, lmax, use_ssht=False)

    # Slice both scores through the SSHT peak so the two rows are comparable.
    # Score axes are taken to be ordered (gamma, beta, alpha).
    best_index = jnp.unravel_index(jnp.argmax(ssht_score), jnp.shape(ssht_score))
    best_gamma, best_beta, best_alpha = best_index
    betas = jnp.array(sampling.s2_samples.thetas(lmax, sampling=S2FFT_ANGLE_SAMPLER))
    # NOTE(review): gamma reuses the equiangular phi grid; this presumably only
    # matches because the Wigner inverse was called with N == lmax — confirm.
    alphas = gammas = jnp.array(
        sampling.s2_samples.phis_equiang(lmax, sampling=S2FFT_ANGLE_SAMPLER)
    )

    fig, axes = plt.subplots(ncols=3, nrows=2, squeeze=False)
    try:
        plt.subplots_adjust(
            left=0.1, right=0.9, top=0.95, bottom=0.05, hspace=0.4, wspace=0.3
        )
        for row, (score, name) in enumerate([(ssht_score, "SSHT"), (jax_score, "JAX")]):
            axes[row, 0].plot(alphas, score[best_gamma, best_beta, :])
            axes[row, 0].title.set_text(f"Corr vs α for {name}")
            axes[row, 1].plot(betas, score[best_gamma, :, best_alpha])
            axes[row, 1].title.set_text(f"Corr vs β for {name}")
            axes[row, 2].plot(gammas, score[:, best_beta, best_alpha])
            axes[row, 2].title.set_text(f"Corr vs γ for {name}")
        fig.savefig(output_path, bbox_inches="tight", dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
# Script entry point: generate and save the comparison figure.
if __name__ == "__main__":
    plot_strange_results()
Metadata
Assignees
Labels
No labels