Skip to content

Commit 84fbf0b

Browse files
committed
Reducing compile time of JAX HEALPix (I)FFT implementations
1 parent 4ef9c67 commit 84fbf0b

File tree

1 file changed

+63
-30
lines changed

1 file changed

+63
-30
lines changed

s2fft/utils/healpix_ffts.py

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from jax import jit
1+
from jax import jit, vmap
22

33
import numpy as np
44
import jax.numpy as jnp
@@ -220,23 +220,47 @@ def healpix_fft_jax(f: jnp.ndarray, L: int, nside: int, reality: bool) -> jnp.nd
220220
Returns:
221221
jnp.ndarray: Array of Fourier coefficients for all latitudes.
222222
"""
223-
ntheta = samples.ntheta(L, "healpix", nside)
224-
index = 0
225-
ftm_rows = []
226-
for t in range(ntheta):
227-
nphi = samples.nphi_ring(t, nside)
223+
224+
def f_chunks_to_ftm_rows(f_chunks, nphi):
228225
if reality and nphi == 2 * L:
229-
fm_chunk = jnp.zeros(nphi, dtype=jnp.complex128)
230-
fm_chunk = fm_chunk.at[nphi // 2 :].set(
231-
jnp.fft.rfft(jnp.real(f[index : index + nphi]), norm="backward")[:-1]
226+
fm_chunks = jnp.concatenate(
227+
(
228+
jnp.zeros((f_chunks.shape[0], nphi // 2)),
229+
jnp.fft.rfft(jnp.real(f_chunks), norm="backward")[:, :-1],
230+
),
231+
axis=1,
232232
)
233233
else:
234-
fm_chunk = jnp.fft.fftshift(
235-
jnp.fft.fft(f[index : index + nphi], norm="backward")
234+
fm_chunks = jnp.fft.fftshift(
235+
jnp.fft.fft(f_chunks, norm="backward"), axes=-1
236236
)
237-
ftm_rows.append(spectral_periodic_extension_jax(fm_chunk, L))
238-
index += nphi
239-
return jnp.stack(ftm_rows)
237+
return vmap(spectral_periodic_extension_jax, (0, None))(fm_chunks, L)
238+
239+
# Process f chunks corresponding to pairs of polar theta rings with the same number
240+
# of phi samples together to reduce size of unrolled traced computational graph
241+
ftm_rows_polar = []
242+
start_index, end_index = 0, 12 * nside**2
243+
for t in range(0, nside - 1):
244+
nphi = 4 * (t + 1)
245+
f_chunks = jnp.stack(
246+
(f[start_index : start_index + nphi], f[end_index - nphi : end_index])
247+
)
248+
ftm_rows_polar.append(f_chunks_to_ftm_rows(f_chunks, nphi))
249+
start_index, end_index = start_index + nphi, end_index - nphi
250+
ftm_rows_polar = jnp.stack(ftm_rows_polar)
251+
# Process all f chunks for the equal sized equatorial theta rings together
252+
nphi = 4 * nside
253+
f_chunks_equatorial = f[start_index:end_index].reshape((-1, nphi))
254+
ftm_rows_equatorial = f_chunks_to_ftm_rows(f_chunks_equatorial, nphi)
255+
# Concatenate Fourier coefficients for all latitudes, reversing second polar set to
256+
# account for processing order
257+
return jnp.concatenate(
258+
(
259+
ftm_rows_polar[:, 0],
260+
ftm_rows_equatorial,
261+
ftm_rows_polar[::-1, 1],
262+
)
263+
)
240264

241265

242266
def healpix_ifft(
@@ -336,28 +360,37 @@ def healpix_ifft_jax(
336360
Returns:
337361
jnp.ndarray: HEALPix pixel-space array.
338362
"""
339-
f = jnp.zeros(
340-
samples.f_shape(sampling="healpix", nside=nside), dtype=jnp.complex128
341-
)
342-
ntheta = ftm.shape[0]
343-
index = 0
344363

345-
for t in range(ntheta):
346-
nphi = samples.nphi_ring(t, nside)
347-
fm_chunk = ftm[t] if nphi == 2 * L else spectral_folding_jax(ftm[t], nphi, L)
364+
def ftm_rows_to_f_chunks(ftm_rows, nphi):
365+
fm_chunks = (
366+
ftm_rows
367+
if nphi == 2 * L
368+
else vmap(spectral_folding_jax, (0, None, None))(ftm_rows, nphi, L)
369+
)
348370
if reality and nphi == 2 * L:
349-
f = f.at[index : index + nphi].set(
350-
jnp.fft.irfft(fm_chunk[nphi // 2 :], nphi, norm="forward")
351-
)
371+
return jnp.fft.irfft(fm_chunks[:, nphi // 2 :], nphi, norm="forward")
352372
else:
353-
f = f.at[index : index + nphi].set(
354-
jnp.conj(
355-
jnp.fft.fft(jnp.fft.ifftshift(jnp.conj(fm_chunk)), norm="backward")
373+
return jnp.conj(
374+
jnp.fft.fft(
375+
jnp.fft.ifftshift(jnp.conj(fm_chunks), axes=-1), norm="backward"
356376
)
357377
)
358378

359-
index += nphi
360-
return f
379+
# Process ftm rows corresponding to pairs of polar theta rings with the same number
380+
# of phi samples together to reduce size of unrolled traced computational graph
381+
f_chunks_polar = [
382+
ftm_rows_to_f_chunks(jnp.stack((ftm[t], ftm[-(t + 1)])), 4 * (t + 1))
383+
for t in range(nside - 1)
384+
]
385+
# Process all ftm rows for the equal sized equatorial theta rings together
386+
f_chunks_equatorial = ftm_rows_to_f_chunks(ftm[nside - 1 : 3 * nside], 4 * nside)
387+
# Concatenate f chunks for all theta rings together, reversing second polar set
388+
# to account for processing order
389+
return jnp.concatenate(
390+
[f_chunks_polar[t][0] for t in range(nside - 1)]
391+
+ [f_chunks_equatorial.flatten()]
392+
+ [f_chunks_polar[t][1] for t in reversed(range(nside - 1))]
393+
)
361394

362395

363396
def p2phi_rings(t: np.ndarray, nside: int) -> np.ndarray:

0 commit comments

Comments
 (0)