Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/api/precompute_transforms/fourier_wigner.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
:html_theme.sidebar_secondary.remove:

**************************
Fourier-Wigner Transform
**************************
.. automodule:: s2fft.precompute_transforms.fourier_wigner
:members:
20 changes: 20 additions & 0 deletions docs/api/precompute_transforms/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,21 @@ Precompute Functions
* - :func:`~s2fft.precompute_transforms.wigner.forward_transform_torch`
- Forward Wigner transform (Torch)

.. list-table:: Fourier-Wigner transforms.
:widths: 25 25
:header-rows: 1

* - Function Name
- Description
* - :func:`~s2fft.precompute_transforms.fourier_wigner.inverse_transform`
- Inverse Wigner transform with Fourier method (NumPy)
* - :func:`~s2fft.precompute_transforms.fourier_wigner.inverse_transform_jax`
- Inverse Wigner transform with Fourier method (JAX)
* - :func:`~s2fft.precompute_transforms.fourier_wigner.forward_transform`
- Forward Wigner transform with Fourier method (NumPy)
* - :func:`~s2fft.precompute_transforms.fourier_wigner.forward_transform_jax`
- Forward Wigner transform with Fourier method (JAX)

.. list-table:: Constructing Kernels for precompute transforms.
:widths: 25 25
:header-rows: 1
Expand All @@ -64,6 +79,10 @@ Precompute Functions
- Builds a kernel including quadrature weights and Wigner-D coefficients for spherical harmonic transform (JAX).
* - :func:`~s2fft.precompute_transforms.construct.wigner_kernel_jax`
- Builds a kernel including quadrature weights and Wigner-D coefficients for Wigner transform (JAX).
* - :func:`~s2fft.precompute_transforms.construct.fourier_wigner_kernel`
- Builds a kernel including quadrature weights and Fourier coefficienfs of Wigner d-functions
* - :func:`~s2fft.precompute_transforms.construct.fourier_wigner_kernel_jax`
- Builds a kernel including quadrature weights and Fourier coefficienfs of Wigner d-functions (JAX).
* - :func:`~s2fft.precompute_transforms.construct.healpix_phase_shifts`
- Builds a vector of corresponding phase shifts for each HEALPix latitudinal ring.

Expand All @@ -76,4 +95,5 @@ Precompute Functions
alt_construct
spin_spherical
wigner
fourier_wigner

2 changes: 1 addition & 1 deletion s2fft/precompute_transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import construct, spherical, wigner
from . import construct, fourier_wigner, spherical, wigner
57 changes: 57 additions & 0 deletions s2fft/precompute_transforms/construct.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Tuple
from warnings import warn

import jax
Expand Down Expand Up @@ -610,6 +611,62 @@ def wigner_kernel_jax(
return dl


def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
weights upsampled for the forward Fourier-Wigner transform.

Args:
L (int): Harmonic band-limit.

Returns:
Tuple[np.ndarray, np.ndarray]: Tuple of delta Fourier coefficients and weights.

"""
# Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices)
deltas = np.zeros((L, 2 * L - 1, 2 * L - 1), dtype=np.float64)
d = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64)
for el in range(L):
d = recursions.risbo.compute_full(d, np.pi / 2, L, el)
deltas[el] = d

# Calculate upsampled quadrature weights
w = np.zeros(4 * L - 3, dtype=np.complex128)
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
w[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm)
w = np.fft.ifft(np.fft.ifftshift(w), norm="forward")

return deltas, w


def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
weights upsampled for the forward Fourier-Wigner transform (JAX implementation).

Args:
L (int): Harmonic band-limit.

Returns:
Tuple[jnp.ndarray, jnp.ndarray]: Tuple of delta Fourier coefficients and weights.

"""
# Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices)
deltas = jnp.zeros((L, 2 * L - 1, 2 * L - 1), dtype=jnp.float64)
d = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
for el in range(L):
d = recursions.risbo_jax.compute_full(d, jnp.pi / 2, L, el)
deltas = deltas.at[el].set(d)

# Calculate upsampled quadrature weights
w = jnp.zeros(4 * L - 3, dtype=jnp.complex128)
for mm in range(-2 * (L - 1), 2 * (L - 1) + 1):
w = w.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm))
w = jnp.fft.ifft(jnp.fft.ifftshift(w), norm="forward")

return deltas, w


def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray:
r"""
Generates a phase shift vector for HEALPix for all :math:`\theta` rings.
Expand Down
Loading
Loading