|
| 1 | +from typing import Tuple |
1 | 2 | from warnings import warn |
2 | 3 |
|
3 | 4 | import jax |
@@ -610,6 +611,62 @@ def wigner_kernel_jax( |
610 | 611 | return dl |
611 | 612 |
|
612 | 613 |
|
| 614 | +def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]: |
| 615 | + """ |
| 616 | + Computes Fourier coefficients of the reduced Wigner d-functions and quadrature |
| 617 | + weights upsampled for the forward Fourier-Wigner transform. |
| 618 | +
|
| 619 | + Args: |
| 620 | + L (int): Harmonic band-limit. |
| 621 | +
|
| 622 | + Returns: |
| 623 | + Tuple[np.ndarray, np.ndarray]: Tuple of delta Fourier coefficients and weights. |
| 624 | +
|
| 625 | + """ |
| 626 | + # Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices) |
| 627 | + deltas = np.zeros((L, 2 * L - 1, 2 * L - 1), dtype=np.float64) |
| 628 | + d = np.zeros((2 * L - 1, 2 * L - 1), dtype=np.float64) |
| 629 | + for el in range(L): |
| 630 | + d = recursions.risbo.compute_full(d, np.pi / 2, L, el) |
| 631 | + deltas[el] = d |
| 632 | + |
| 633 | + # Calculate upsampled quadrature weights |
| 634 | + w = np.zeros(4 * L - 3, dtype=np.complex128) |
| 635 | + for mm in range(-2 * (L - 1), 2 * (L - 1) + 1): |
| 636 | + w[mm + 2 * (L - 1)] = quadrature.mw_weights(-mm) |
| 637 | + w = np.fft.ifft(np.fft.ifftshift(w), norm="forward") |
| 638 | + |
| 639 | + return deltas, w |
| 640 | + |
| 641 | + |
| 642 | +def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]: |
| 643 | + """ |
| 644 | + Computes Fourier coefficients of the reduced Wigner d-functions and quadrature |
| 645 | + weights upsampled for the forward Fourier-Wigner transform (JAX implementation). |
| 646 | +
|
| 647 | + Args: |
| 648 | + L (int): Harmonic band-limit. |
| 649 | +
|
| 650 | + Returns: |
| 651 | + Tuple[jnp.ndarray, jnp.ndarray]: Tuple of delta Fourier coefficients and weights. |
| 652 | +
|
| 653 | + """ |
| 654 | + # Calculate deltas (np.pi/2 Fourier coefficients of Wigner matrices) |
| 655 | + deltas = jnp.zeros((L, 2 * L - 1, 2 * L - 1), dtype=jnp.float64) |
| 656 | + d = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64) |
| 657 | + for el in range(L): |
| 658 | + d = recursions.risbo_jax.compute_full(d, jnp.pi / 2, L, el) |
| 659 | + deltas = deltas.at[el].set(d) |
| 660 | + |
| 661 | + # Calculate upsampled quadrature weights |
| 662 | + w = jnp.zeros(4 * L - 3, dtype=jnp.complex128) |
| 663 | + for mm in range(-2 * (L - 1), 2 * (L - 1) + 1): |
| 664 | + w = w.at[mm + 2 * (L - 1)].set(quadrature_jax.mw_weights(-mm)) |
| 665 | + w = jnp.fft.ifft(jnp.fft.ifftshift(w), norm="forward") |
| 666 | + |
| 667 | + return deltas, w |
| 668 | + |
| 669 | + |
613 | 670 | def healpix_phase_shifts(L: int, nside: int, forward: bool = False) -> np.ndarray: |
614 | 671 | r""" |
615 | 672 | Generates a phase shift vector for HEALPix for all :math:`\theta` rings. |
|
0 commit comments