Skip to content

s2wav.analysis fails with AssertionError for HEALPix map input #84

@1cosmologist

Description

@1cosmologist

I am trying to compute the directional wavelet transform of a HEALPix map. I have tried both s2wav.analysis and s2wav.wavelet.flm_to_analysis (in the latter case computing the map-to-flm step separately with s2fft). Both raise an AssertionError.

Minimal example:

nside = 128
lmax = 2 * nside 
N = 3

hpx_map = np.ones((12*nside**2,))

filter_bank = sw.filters.filters_directional_vectorised(lmax, N)
wavelet_coeffs, scaling_coeffs = sw.analysis(hpx_map, lmax, N, nside=nside, filters=filter_bank, sampling='healpix')

It fails with the following error message:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[116], line 8
      5 hpx_map = np.ones((12*nside**2,))
      7 filter_bank = sw.filters.filters_directional_vectorised(lmax, N)
----> 8 wavelet_coeffs, scaling_coeffs = sw.analysis(hpx_map, lmax, N, nside=nside, filters=filter_bank, sampling='healpix')

    [... skipping hidden 11 frame]

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2wav/transforms/wavelet.py:189, in analysis(f, L, N, J_min, lam, spin, sampling, nside, reality, filters, precomps)
    174     Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True)
    175     f_wav_lmn[j - J_min] = (
    176         f_wav_lmn[j - J_min]
    177         .at[::2, L0j:]
   (...)
    185         )
    186     )
    188     f_wav.append(
--> 189         s2fft.wigner.inverse_jax(
    190             f_wav_lmn[j - J_min],
    191             Lj,
    192             Nj,
    193             nside,
    194             sampling,
    195             reality,
    196             precomps[j - J_min],
    197             L0j,
    198         )
    199     )
    201 # Project all harmonic coefficients for each lm onto scaling coefficients
    202 phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1))

    [... skipping hidden 11 frame]

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/wigner.py:257, in inverse_jax(flmn, L, N, nside, sampling, reality, precomps, L_lower)
    251     precomps = [p0, p1, p2, p3, p4]
    252     return (-1) ** jnp.abs(spin) * s2fft.inverse_jax(
    253         flm, L, -spin, nside, sampling, False, precomps, False, L_lower
    254     )
    256 fban = fban.at[N - 1 + n_start_ind :].set(
--> 257     vmap(
    258         partial(func, p2=precomps[2][0], p3=precomps[3][0], p4=precomps[4][0]),
    259         in_axes=(0, 0, 0, 0),
    260     )(flmn[N - 1 + n_start_ind :], spins, precomps[0], precomps[1])
    261 )
    262 if reality:
    263     f = jnp.fft.irfft(fban[N - 1 :], 2 * N - 1, axis=0, norm="forward")

    [... skipping hidden 3 frame]

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/wigner.py:252, in inverse_jax.<locals>.func(flm, spin, p0, p1, p2, p3, p4)
    250 def func(flm, spin, p0, p1, p2, p3, p4):
    251     precomps = [p0, p1, p2, p3, p4]
--> 252     return (-1) ** jnp.abs(spin) * s2fft.inverse_jax(
    253         flm, L, -spin, nside, sampling, False, precomps, False, L_lower
    254     )

    [... skipping hidden 11 frame]

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/spherical.py:319, in inverse_jax(flm, L, spin, nside, sampling, reality, precomps, spmd, L_lower)
    315     ftm = ftm.at[:, m_offset : L - 1 + m_offset].set(
    316         jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1)
    317     )
    318 if sampling.lower() == "healpix":
--> 319     return hp.healpix_ifft(ftm, L, nside, "jax")
    320 else:
    321     ftm = jnp.conj(jnp.fft.ifftshift(ftm, axes=1))

File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/utils/healpix_ffts.py:398, in healpix_ifft(ftm, L, nside, method, reality)
    368 def healpix_ifft(
    369     ftm: np.ndarray,
    370     L: int,
   (...)
    373     reality: bool = False,
    374 ) -> np.ndarray:
    375     """Wrapper function for the Inverse Fast Fourier Transform with spectral folding
    376     in the polar regions to mitigate aliasing.
    377 
   (...)
    396         np.ndarray: HEALPix pixel-space array.
    397     """
--> 398     assert L >= 2 * nside
    399     if method.lower() == \"numpy\":
    400         return healpix_ifft_numpy(ftm, L, nside, reality)

AssertionError: 

Metadata

Assignees

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions