|
1 | | -from jax import jit |
| 1 | +from jax import jit, vmap |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 | import jax.numpy as jnp |
@@ -220,23 +220,47 @@ def healpix_fft_jax(f: jnp.ndarray, L: int, nside: int, reality: bool) -> jnp.nd |
220 | 220 | Returns: |
221 | 221 | jnp.ndarray: Array of Fourier coefficients for all latitudes. |
222 | 222 | """ |
223 | | - ntheta = samples.ntheta(L, "healpix", nside) |
224 | | - index = 0 |
225 | | - ftm_rows = [] |
226 | | - for t in range(ntheta): |
227 | | - nphi = samples.nphi_ring(t, nside) |
| 223 | + |
| 224 | + def f_chunks_to_ftm_rows(f_chunks, nphi): |
228 | 225 | if reality and nphi == 2 * L: |
229 | | - fm_chunk = jnp.zeros(nphi, dtype=jnp.complex128) |
230 | | - fm_chunk = fm_chunk.at[nphi // 2 :].set( |
231 | | - jnp.fft.rfft(jnp.real(f[index : index + nphi]), norm="backward")[:-1] |
| 226 | + fm_chunks = jnp.concatenate( |
| 227 | + ( |
| 228 | + jnp.zeros((f_chunks.shape[0], nphi // 2)), |
| 229 | + jnp.fft.rfft(jnp.real(f_chunks), norm="backward")[:, :-1], |
| 230 | + ), |
| 231 | + axis=1, |
232 | 232 | ) |
233 | 233 | else: |
234 | | - fm_chunk = jnp.fft.fftshift( |
235 | | - jnp.fft.fft(f[index : index + nphi], norm="backward") |
| 234 | + fm_chunks = jnp.fft.fftshift( |
| 235 | + jnp.fft.fft(f_chunks, norm="backward"), axes=-1 |
236 | 236 | ) |
237 | | - ftm_rows.append(spectral_periodic_extension_jax(fm_chunk, L)) |
238 | | - index += nphi |
239 | | - return jnp.stack(ftm_rows) |
| 237 | + return vmap(spectral_periodic_extension_jax, (0, None))(fm_chunks, L) |
| 238 | + |
| 239 | + # Process f chunks corresponding to pairs of polar theta rings with the same number |
| 240 | + # of phi samples together to reduce size of unrolled traced computational graph |
| 241 | + ftm_rows_polar = [] |
| 242 | + start_index, end_index = 0, 12 * nside**2 |
| 243 | + for t in range(0, nside - 1): |
| 244 | + nphi = 4 * (t + 1) |
| 245 | + f_chunks = jnp.stack( |
| 246 | + (f[start_index : start_index + nphi], f[end_index - nphi : end_index]) |
| 247 | + ) |
| 248 | + ftm_rows_polar.append(f_chunks_to_ftm_rows(f_chunks, nphi)) |
| 249 | + start_index, end_index = start_index + nphi, end_index - nphi |
| 250 | + ftm_rows_polar = jnp.stack(ftm_rows_polar) |
| 251 | + # Process all f chunks for the equal sized equatorial theta rings together |
| 252 | + nphi = 4 * nside |
| 253 | + f_chunks_equatorial = f[start_index:end_index].reshape((-1, nphi)) |
| 254 | + ftm_rows_equatorial = f_chunks_to_ftm_rows(f_chunks_equatorial, nphi) |
| 255 | + # Concatenate Fourier coefficients for all latitudes, reversing second polar set to |
| 256 | + # account for processing order |
| 257 | + return jnp.concatenate( |
| 258 | + ( |
| 259 | + ftm_rows_polar[:, 0], |
| 260 | + ftm_rows_equatorial, |
| 261 | + ftm_rows_polar[::-1, 1], |
| 262 | + ) |
| 263 | + ) |
240 | 264 |
|
241 | 265 |
|
242 | 266 | def healpix_ifft( |
@@ -336,28 +360,37 @@ def healpix_ifft_jax( |
336 | 360 | Returns: |
337 | 361 | jnp.ndarray: HEALPix pixel-space array. |
338 | 362 | """ |
339 | | - f = jnp.zeros( |
340 | | - samples.f_shape(sampling="healpix", nside=nside), dtype=jnp.complex128 |
341 | | - ) |
342 | | - ntheta = ftm.shape[0] |
343 | | - index = 0 |
344 | 363 |
|
345 | | - for t in range(ntheta): |
346 | | - nphi = samples.nphi_ring(t, nside) |
347 | | - fm_chunk = ftm[t] if nphi == 2 * L else spectral_folding_jax(ftm[t], nphi, L) |
| 364 | + def ftm_rows_to_f_chunks(ftm_rows, nphi): |
| 365 | + fm_chunks = ( |
| 366 | + ftm_rows |
| 367 | + if nphi == 2 * L |
| 368 | + else vmap(spectral_folding_jax, (0, None, None))(ftm_rows, nphi, L) |
| 369 | + ) |
348 | 370 | if reality and nphi == 2 * L: |
349 | | - f = f.at[index : index + nphi].set( |
350 | | - jnp.fft.irfft(fm_chunk[nphi // 2 :], nphi, norm="forward") |
351 | | - ) |
| 371 | + return jnp.fft.irfft(fm_chunks[:, nphi // 2 :], nphi, norm="forward") |
352 | 372 | else: |
353 | | - f = f.at[index : index + nphi].set( |
354 | | - jnp.conj( |
355 | | - jnp.fft.fft(jnp.fft.ifftshift(jnp.conj(fm_chunk)), norm="backward") |
| 373 | + return jnp.conj( |
| 374 | + jnp.fft.fft( |
| 375 | + jnp.fft.ifftshift(jnp.conj(fm_chunks), axes=-1), norm="backward" |
356 | 376 | ) |
357 | 377 | ) |
358 | 378 |
|
359 | | - index += nphi |
360 | | - return f |
| 379 | + # Process ftm rows corresponding to pairs of polar theta rings with the same number |
| 380 | + # of phi samples together to reduce size of unrolled traced computational graph |
| 381 | + f_chunks_polar = [ |
| 382 | + ftm_rows_to_f_chunks(jnp.stack((ftm[t], ftm[-(t + 1)])), 4 * (t + 1)) |
| 383 | + for t in range(nside - 1) |
| 384 | + ] |
| 385 | + # Process all ftm rows for the equal sized equatorial theta rings together |
| 386 | + f_chunks_equatorial = ftm_rows_to_f_chunks(ftm[nside - 1 : 3 * nside], 4 * nside) |
| 387 | + # Concatenate f chunks for all theta rings together, reversing second polar set |
| 388 | + # to account for processing order |
| 389 | + return jnp.concatenate( |
| 390 | + [f_chunks_polar[t][0] for t in range(nside - 1)] |
| 391 | + + [f_chunks_equatorial.flatten()] |
| 392 | + + [f_chunks_polar[t][1] for t in reversed(range(nside - 1))] |
| 393 | + ) |
361 | 394 |
|
362 | 395 |
|
363 | 396 | def p2phi_rings(t: np.ndarray, nside: int) -> np.ndarray: |
|
0 commit comments