-
Notifications
You must be signed in to change notification settings - Fork 1
Open
Description
I am trying to compute the directional wavelet transform of a HEALPix map. I have tried both `s2wav.analysis` and `s2wav.wavelet.flm_to_analysis` (for the latter, I computed the map-to-flm conversion separately with s2fft). In both cases I get an `AssertionError`.
Minimal example (`np` is NumPy and `sw` is the `s2wav` package):
nside = 128
lmax = 2 * nside
N = 3
hpx_map = np.ones((12*nside**2,))
filter_bank = sw.filters.filters_directional_vectorised(lmax, N)
wavelet_coeffs, scaling_coeffs = sw.analysis(hpx_map, lmax, N, nside=nside, filters=filter_bank, sampling='healpix')
This fails with the following error message:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[116], line 8
5 hpx_map = np.ones((12*nside**2,))
7 filter_bank = sw.filters.filters_directional_vectorised(lmax, N)
----> 8 wavelet_coeffs, scaling_coeffs = sw.analysis(hpx_map, lmax, N, nside=nside, filters=filter_bank, sampling='healpix')
[... skipping hidden 11 frame]
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2wav/transforms/wavelet.py:189, in analysis(f, L, N, J_min, lam, spin, sampling, nside, reality, filters, precomps)
174 Lj, Nj, L0j = samples.LN_j(L, j, N, lam, True)
175 f_wav_lmn[j - J_min] = (
176 f_wav_lmn[j - J_min]
177 .at[::2, L0j:]
(...)
185 )
186 )
188 f_wav.append(
--> 189 s2fft.wigner.inverse_jax(
190 f_wav_lmn[j - J_min],
191 Lj,
192 Nj,
193 nside,
194 sampling,
195 reality,
196 precomps[j - J_min],
197 L0j,
198 )
199 )
201 # Project all harmonic coefficients for each lm onto scaling coefficients
202 phi = filters[1][:Ls] * jnp.sqrt(4 * jnp.pi / (2 * jnp.arange(Ls) + 1))
[... skipping hidden 11 frame]
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/wigner.py:257, in inverse_jax(flmn, L, N, nside, sampling, reality, precomps, L_lower)
251 precomps = [p0, p1, p2, p3, p4]
252 return (-1) ** jnp.abs(spin) * s2fft.inverse_jax(
253 flm, L, -spin, nside, sampling, False, precomps, False, L_lower
254 )
256 fban = fban.at[N - 1 + n_start_ind :].set(
--> 257 vmap(
258 partial(func, p2=precomps[2][0], p3=precomps[3][0], p4=precomps[4][0]),
259 in_axes=(0, 0, 0, 0),
260 )(flmn[N - 1 + n_start_ind :], spins, precomps[0], precomps[1])
261 )
262 if reality:
263 f = jnp.fft.irfft(fban[N - 1 :], 2 * N - 1, axis=0, norm="forward")
[... skipping hidden 3 frame]
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/wigner.py:252, in inverse_jax.<locals>.func(flm, spin, p0, p1, p2, p3, p4)
250 def func(flm, spin, p0, p1, p2, p3, p4):
251 precomps = [p0, p1, p2, p3, p4]
--> 252 return (-1) ** jnp.abs(spin) * s2fft.inverse_jax(
253 flm, L, -spin, nside, sampling, False, precomps, False, L_lower
254 )
[... skipping hidden 11 frame]
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/transforms/spherical.py:319, in inverse_jax(flm, L, spin, nside, sampling, reality, precomps, spmd, L_lower)
315 ftm = ftm.at[:, m_offset : L - 1 + m_offset].set(
316 jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1)
317 )
318 if sampling.lower() == "healpix":
--> 319 return hp.healpix_ifft(ftm, L, nside, \"jax\")
320 else:
321 ftm = jnp.conj(jnp.fft.ifftshift(ftm, axes=1))
File /pscratch/sd/s/shamikg/cmbenv/master-0.0.1/conda/lib/python3.10/site-packages/s2fft/utils/healpix_ffts.py:398, in healpix_ifft(ftm, L, nside, method, reality)
368 def healpix_ifft(
369 ftm: np.ndarray,
370 L: int,
(...)
373 reality: bool = False,
374 ) -> np.ndarray:
375 """Wrapper function for the Inverse Fast Fourier Transform with spectral folding
376 in the polar regions to mitigate aliasing.
377
(...)
396 np.ndarray: HEALPix pixel-space array.
397 """
--> 398 assert L >= 2 * nside
399 if method.lower() == \"numpy\":
400 return healpix_ifft_numpy(ftm, L, nside, reality)
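AssertionError:

Note: the failing check is `assert L >= 2 * nside` in `s2fft.utils.healpix_ffts.healpix_ifft`. The overall band-limit `lmax = 2 * nside` satisfies this check. However, `s2wav.analysis` calls `s2fft.wigner.inverse_jax` with the per-scale band-limit `Lj` returned by `samples.LN_j`, not with `lmax`. For at least one wavelet scale, `Lj` is evidently smaller than `2 * nside`, so the HEALPix inverse transform at that scale trips the assertion.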
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels