Skip to content

Inconsistent results for pure JAX implementation of inverse Wigner transformation #209

@ElisR

Description

@ElisR

Hi! First, thanks a lot for developing this package! I don't work in astrophysics, but the package has been exactly what I needed for a project I'm working on.

The issue I'm having is with the JAX implementation of the inverse Wigner transform, `s2fft.wigner.inverse_jax`. Luckily, I stumbled upon the SSHT-backed version of this function (`s2fft.wigner.inverse_jax_ssht`), which let me implement what I needed, albeit only on the CPU.

Below, I have added a minimal (non-)working example to illustrate the kind of use case I have.
Apologies that it is still somewhat verbose.
I haven't dug into the library code to debug this, but judging by the magnitudes of the outputs it looks like catastrophic floating-point errors are occurring (beyond the aliasing artefacts I originally assumed were the cause).

"""Minimum example showing `s2fft`'s inverse Wigner transformation not working reliably."""

from collections.abc import Callable
from typing import TypeAlias

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.scipy.spatial import transform

# Allowing 64-bit float before importing s2fft
jax.config.update("jax_enable_x64", True)
import s2fft  # noqa: E402  pylint: disable=C0413
from s2fft import partial, sampling  # noqa: E402  pylint: disable=C0413

# Signature of the scalar fields used in this script: maps an array of Cartesian
# points (components along the last axis) to an array of values.
ScalarFunction: TypeAlias = Callable[[jax.Array], jax.Array]


# Sampling scheme name passed to every s2fft sampling/transform call below
# (presumably Driscoll-Healy for "dh" — confirm against the s2fft docs).
S2FFT_ANGLE_SAMPLER: str = "dh"
# Passed as ``reality=`` to every s2fft transform call below; presumably False
# means the signal is not assumed to be real-valued — confirm against s2fft docs.
S2FFT_REALITY: bool = False


def dxy_orbital(coords: jax.Array) -> jax.Array:
    """Evaluate an (arbitrary) function that looks like a 3d_xy orbital.

    Computes ``exp(-r / 2) * x * y`` on coordinates scaled by 8, where ``r`` is
    the Euclidean norm over the last axis.

    Args:
        coords: Points with Cartesian components ``(x, y, z)`` along the last
            axis. The input is not modified.

    Returns:
        Array of function values with the last axis of ``coords`` reduced away.
    """
    # Scale coordinates so the shape is mostly contained in the frame.
    # Done out of place: ``coords *= 8.0`` would mutate NumPy inputs in place,
    # silently changing the caller's data.
    scaled = coords * 8.0
    x_s = scaled[..., 0]
    y_s = scaled[..., 1]
    r_s = jnp.linalg.vector_norm(scaled, axis=-1)
    return jnp.exp(-r_s / 2) * x_s * y_s


def _generate_samples(func: ScalarFunction, thetas: jax.Array, phis: jax.Array) -> jax.Array:
    """Evaluate ``func`` on the Cartesian points of a (theta, phi) angle grid.

    The array handed to ``func`` has leading axes ``(theta, phi)`` and a last
    axis holding the ``(x, y, z)`` coordinates of each grid point, so the result
    is indexed as ``(theta, phi)`` as well.
    """
    sin_theta = jnp.sin(thetas)[:, None]
    cos_theta = jnp.cos(thetas)[:, None]
    grid_x = sin_theta * jnp.cos(phis)[None, :]
    grid_y = sin_theta * jnp.sin(phis)[None, :]
    # z is independent of phi; adding a zero grid broadcasts it to full shape.
    grid_z = cos_theta + 0 * grid_x
    points = jnp.stack([grid_x, grid_y, grid_z], axis=-1)
    return func(points)


def get_correlation(
    func: ScalarFunction, rotation: jax.Array, lmax: int, *, use_ssht: bool
) -> jax.Array:
    """Compute the SO(3) correlation between ``func`` and a rotated copy of it.

    Both ``func`` and ``func`` rotated by ``rotation`` are sampled on the
    ``S2FFT_ANGLE_SAMPLER`` grid and transformed with ``s2fft.forward_jax``.
    The outer product of the conjugated rotated coefficients with the unrotated
    coefficients is then passed to an inverse Wigner transform with
    ``L = N = lmax``. This is used to show that the pure-JAX inverse Wigner
    transform does not agree with the SSHT-backed one.

    Args:
        func: Scalar function evaluated on Cartesian points stacked along the
            last axis.
        rotation: 3x3 rotation matrix; the rotated function is evaluated at
            ``rotation.T @ x``.
        lmax: Band-limit passed to s2fft for sampling and as both ``L`` and
            ``N`` of the inverse Wigner transform.
        use_ssht: If True use ``s2fft.wigner.inverse_jax_ssht``, otherwise the
            pure-JAX ``s2fft.wigner.inverse_jax``.

    Returns:
        Real part of the signal returned by the chosen inverse Wigner transform.
    """

    # A plain closure instead of relying on ``partial`` being importable from
    # ``s2fft`` (presumably an incidental re-export of ``functools.partial``).
    def func_rotated(coords: jax.Array) -> jax.Array:
        """Evaluate ``func`` rotated by ``rotation``, i.e. ``func(rotation.T @ x)``."""
        rotated_coords = jnp.einsum("ij,...j->...i", rotation.T, coords)
        return func(rotated_coords)

    thetas = jnp.array(sampling.s2_samples.thetas(lmax, sampling=S2FFT_ANGLE_SAMPLER))
    phis = jnp.array(sampling.s2_samples.phis_equiang(lmax, sampling=S2FFT_ANGLE_SAMPLER))

    def _get_coeffs(f: ScalarFunction) -> jax.Array:
        """Sample ``f`` on the (theta, phi) grid and return its s2fft forward transform."""
        sampled = _generate_samples(f, thetas, phis)
        return s2fft.forward_jax(
            sampled, lmax, reality=S2FFT_REALITY, sampling=S2FFT_ANGLE_SAMPLER
        )

    func_coeffs = _get_coeffs(func)
    func_rotated_coeffs = _get_coeffs(func_rotated)

    # Broadcast outer product of the (transposed, conjugated) rotated coefficients
    # with the unrotated ones, sharing the middle axis.
    # NOTE(review): presumably yields the (n, el, m) layout s2fft's inverse
    # Wigner transform expects — confirm against the s2fft docs.
    outer_product = (
        jnp.matrix_transpose(jnp.conjugate(func_rotated_coeffs))[:, :, None]
        * func_coeffs[None, :, :]
    )
    # Sanity check that 64-bit mode is active (needs jax_enable_x64).
    assert outer_product.dtype == jnp.complex128

    # This is the part that varies: SSHT-backed vs pure-JAX inverse transform.
    inverse_args = (outer_product, lmax, lmax)
    inverse_kwargs = {"reality": S2FFT_REALITY, "sampling": S2FFT_ANGLE_SAMPLER}
    so3_correlation = (
        s2fft.wigner.inverse_jax_ssht(*inverse_args, **inverse_kwargs)
        if use_ssht
        else s2fft.wigner.inverse_jax(*inverse_args, **inverse_kwargs)
    )
    return jnp.real(so3_correlation)


def plot_strange_results(*, lmax: int = 64) -> None:
    """Plot SSHT vs pure-JAX inverse Wigner correlations side by side.

    Correlates ``dxy_orbital`` with itself rotated by pi/2 about the z-axis,
    once with the SSHT-backed and once with the pure-JAX inverse Wigner
    transform. It then plots 1-D slices of both score volumes through the
    SSHT maximum and saves the figure to ``strange_s2fft_inverse_wigner.png``.

    Args:
        lmax: Band-limit used for sampling and both transforms (default 64,
            matching the original hard-coded value).
    """
    print(f"{jax.config.read('jax_enable_x64')=}")

    # Get scores
    rotation = transform.Rotation.from_rotvec(jnp.array([0, 0, jnp.pi / 2])).as_matrix()
    ssht_score = get_correlation(dxy_orbital, rotation, lmax, use_ssht=True)
    jax_score = get_correlation(dxy_orbital, rotation, lmax, use_ssht=False)

    # Slice both volumes through the SSHT maximum so the curves are comparable.
    best_gamma, best_beta, best_alpha = jnp.unravel_index(
        jnp.argmax(ssht_score), jnp.shape(ssht_score)
    )
    betas = jnp.array(sampling.s2_samples.thetas(lmax, sampling=S2FFT_ANGLE_SAMPLER))
    alphas = gammas = jnp.array(
        sampling.s2_samples.phis_equiang(lmax, sampling=S2FFT_ANGLE_SAMPLER)
    )

    fig, axes = plt.subplots(ncols=3, nrows=2, squeeze=False)
    try:
        # Adjust the explicit figure rather than pyplot's implicit "current" one.
        fig.subplots_adjust(
            left=0.1, right=0.9, top=0.95, bottom=0.05, hspace=0.4, wspace=0.3
        )
        for i, (score, name) in enumerate([(ssht_score, "SSHT"), (jax_score, "JAX")]):
            axes[i, 0].plot(alphas, score[best_gamma, best_beta, :])
            axes[i, 0].title.set_text(f"Corr vs α for {name}")
            axes[i, 1].plot(betas, score[best_gamma, :, best_alpha])
            axes[i, 1].title.set_text(f"Corr vs β for {name}")
            axes[i, 2].plot(gammas, score[:, best_beta, best_alpha])
            axes[i, 2].title.set_text(f"Corr vs γ for {name}")
        fig.savefig("strange_s2fft_inverse_wigner.png", bbox_inches="tight", dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: run the comparison and write the figure to disk.
    plot_strange_results()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions