Skip to content

Commit 6761781

Browse files
committed
Use wrappers for torch HEALPix FFT functions
1 parent 5fde7ad commit 6761781

File tree

1 file changed

+7
-140
lines changed

1 file changed

+7
-140
lines changed

s2fft/utils/healpix_ffts.py

Lines changed: 7 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import jax.numpy as jnp
44
import jaxlib.mlir.ir as ir
55
import numpy as np
6-
import torch
76
from jax import jit, vmap
87

98
# did not find promote_dtypes_complex outside _src
@@ -14,6 +13,7 @@
1413

1514
from s2fft.sampling import s2_samples as samples
1615
from s2fft.utils.jax_primitive import register_primitive
16+
from s2fft.utils.torch_wrapper import wrap_as_torch_function
1717

1818

1919
def spectral_folding(fm: np.ndarray, nphi: int, L: int) -> np.ndarray:
@@ -79,39 +79,7 @@ def spectral_folding_jax(fm: jnp.ndarray, nphi: int, L: int) -> jnp.ndarray:
7979
)
8080

8181

82-
def spectral_folding_torch(fm: torch.tensor, nphi: int, L: int) -> torch.tensor:
83-
"""
84-
Folds higher frequency Fourier coefficients back onto lower frequency
85-
coefficients, i.e. aliasing high frequencies. Torch specific implementation of
86-
:func:`~spectral_folding`.
87-
88-
Args:
89-
fm (torch.tensor): Slice of Fourier coefficients corresponding to ring at latitute t.
90-
91-
nphi (int): Total number of pixel space phi samples for latitude t.
92-
93-
L (int): Harmonic band-limit.
94-
95-
Returns:
96-
torch.tensor: Lower resolution set of aliased Fourier coefficients.
97-
98-
"""
99-
slice_start = L - nphi // 2
100-
slice_stop = slice_start + nphi
101-
ftm_slice = fm[slice_start:slice_stop]
102-
103-
ftm_slice = ftm_slice.put_(
104-
-torch.arange(1, L - nphi // 2 + 1) % nphi,
105-
fm[slice_start - torch.arange(1, L - nphi // 2 + 1)],
106-
accumulate=True,
107-
)
108-
ftm_slice = ftm_slice.put_(
109-
torch.arange(L - nphi // 2) % nphi,
110-
fm[slice_stop + torch.arange(L - nphi // 2)],
111-
accumulate=True,
112-
)
113-
114-
return ftm_slice
82+
spectral_folding_torch = wrap_as_torch_function(spectral_folding_jax)
11583

11684

11785
def spectral_periodic_extension(fm: np.ndarray, nphi: int, L: int) -> np.ndarray:
@@ -174,29 +142,9 @@ def spectral_periodic_extension_jax(fm: jnp.ndarray, L: int) -> jnp.ndarray:
174142
)
175143

176144

177-
def spectral_periodic_extension_torch(fm: torch.tensor, L: int) -> torch.tensor:
178-
"""
179-
Extends lower frequency Fourier coefficients onto higher frequency
180-
coefficients, i.e. imposed periodicity in Fourier space. Based on
181-
:func:`~spectral_periodic_extension`.
182-
183-
Args:
184-
fm (torch.tensor): Slice of Fourier coefficients corresponding to ring at latitute t.
185-
186-
L (int): Harmonic band-limit.
187-
188-
Returns:
189-
torch.tensor: Higher resolution set of periodic Fourier coefficients.
190-
191-
"""
192-
nphi = fm.shape[0]
193-
return torch.concatenate(
194-
(
195-
fm[-torch.arange(L - nphi // 2, 0, -1) % nphi],
196-
fm,
197-
fm[torch.arange(L - (nphi + 1) // 2) % nphi],
198-
)
199-
)
145+
spectral_periodic_extension_torch = wrap_as_torch_function(
146+
spectral_periodic_extension_jax
147+
)
200148

201149

202150
def healpix_fft(
@@ -350,49 +298,7 @@ def f_chunks_to_ftm_rows(f_chunks, nphi):
350298
)
351299

352300

353-
def healpix_fft_torch(
354-
f: torch.tensor, L: int, nside: int, reality: bool
355-
) -> torch.tensor:
356-
"""
357-
Computes the Forward Fast Fourier Transform with spectral back-projection
358-
in the polar regions to manually enforce Fourier periodicity. Torch specific
359-
implementation of :func:`~healpix_fft_numpy`.
360-
361-
Args:
362-
f (torch.tensor): HEALPix pixel-space array.
363-
364-
L (int): Harmonic band-limit.
365-
366-
nside (int): HEALPix Nside resolution parameter.
367-
368-
reality (bool): Whether the signal on the sphere is real. If so,
369-
conjugate symmetry is exploited to reduce computational costs.
370-
371-
Returns:
372-
torch.tensor: Array of Fourier coefficients for all latitudes.
373-
374-
"""
375-
index = 0
376-
ftm = torch.zeros(samples.ftm_shape(L, "healpix", nside), dtype=torch.complex128)
377-
ntheta = ftm.shape[0]
378-
for t in range(ntheta):
379-
nphi = samples.nphi_ring(t, nside)
380-
if reality and nphi == 2 * L:
381-
fm_chunk = torch.zeros(nphi, dtype=torch.complex128)
382-
fm_chunk[nphi // 2 :] = torch.fft.rfft(
383-
torch.real(f[index : index + nphi]), norm="backward"
384-
)[:-1]
385-
else:
386-
fm_chunk = torch.fft.fftshift(
387-
torch.fft.fft(f[index : index + nphi], norm="backward")
388-
)
389-
ftm[t] = (
390-
fm_chunk
391-
if nphi == 2 * L
392-
else spectral_periodic_extension_torch(fm_chunk, L)
393-
)
394-
index += nphi
395-
return ftm
301+
healpix_fft_torch = wrap_as_torch_function(healpix_fft_jax)
396302

397303

398304
def healpix_ifft(
@@ -531,46 +437,7 @@ def ftm_rows_to_f_chunks(ftm_rows, nphi):
531437
)
532438

533439

534-
def healpix_ifft_torch(
535-
ftm: torch.tensor, L: int, nside: int, reality: bool
536-
) -> torch.tensor:
537-
"""
538-
Computes the Inverse Fast Fourier Transform with spectral folding in the polar
539-
regions to mitigate aliasing. Torch specific implementation of
540-
:func:`~healpix_ifft_numpy`.
541-
542-
Args:
543-
ftm (torch.tensor): Array of Fourier coefficients for all latitudes.
544-
545-
L (int): Harmonic band-limit.
546-
547-
nside (int): HEALPix Nside resolution parameter.
548-
549-
reality (bool): Whether the signal on the sphere is real. If so,
550-
conjugate symmetry is exploited to reduce computational costs.
551-
552-
Returns:
553-
torch.tensor: HEALPix pixel-space array.
554-
555-
"""
556-
f = torch.zeros(
557-
samples.f_shape(sampling="healpix", nside=nside), dtype=torch.complex128
558-
)
559-
ntheta = ftm.shape[0]
560-
index = 0
561-
for t in range(ntheta):
562-
nphi = samples.nphi_ring(t, nside)
563-
fm_chunk = ftm[t] if nphi == 2 * L else spectral_folding_torch(ftm[t], nphi, L)
564-
if reality and nphi == 2 * L:
565-
f[index : index + nphi] = torch.fft.irfft(
566-
fm_chunk[nphi // 2 :], nphi, norm="forward"
567-
)
568-
else:
569-
f[index : index + nphi] = torch.fft.ifft(
570-
torch.fft.ifftshift(fm_chunk), norm="forward"
571-
)
572-
index += nphi
573-
return f
440+
healpix_ifft_torch = wrap_as_torch_function(healpix_ifft_jax)
574441

575442

576443
def p2phi_rings(t: np.ndarray, nside: int) -> np.ndarray:

0 commit comments

Comments
 (0)