|
3 | 3 | import jax.numpy as jnp |
4 | 4 | import jaxlib.mlir.ir as ir |
5 | 5 | import numpy as np |
6 | | -import torch |
7 | 6 | from jax import jit, vmap |
8 | 7 |
|
9 | 8 | # did not find promote_dtypes_complex outside _src |
|
14 | 13 |
|
15 | 14 | from s2fft.sampling import s2_samples as samples |
16 | 15 | from s2fft.utils.jax_primitive import register_primitive |
| 16 | +from s2fft.utils.torch_wrapper import wrap_as_torch_function |
17 | 17 |
|
18 | 18 |
|
19 | 19 | def spectral_folding(fm: np.ndarray, nphi: int, L: int) -> np.ndarray: |
@@ -79,39 +79,7 @@ def spectral_folding_jax(fm: jnp.ndarray, nphi: int, L: int) -> jnp.ndarray: |
79 | 79 | ) |
80 | 80 |
|
81 | 81 |
|
82 | | -def spectral_folding_torch(fm: torch.tensor, nphi: int, L: int) -> torch.tensor: |
83 | | - """ |
84 | | - Folds higher frequency Fourier coefficients back onto lower frequency |
85 | | - coefficients, i.e. aliasing high frequencies. Torch specific implementation of |
86 | | - :func:`~spectral_folding`. |
87 | | -
|
88 | | - Args: |
89 | | - fm (torch.tensor): Slice of Fourier coefficients corresponding to ring at latitute t. |
90 | | -
|
91 | | - nphi (int): Total number of pixel space phi samples for latitude t. |
92 | | -
|
93 | | - L (int): Harmonic band-limit. |
94 | | -
|
95 | | - Returns: |
96 | | - torch.tensor: Lower resolution set of aliased Fourier coefficients. |
97 | | -
|
98 | | - """ |
99 | | - slice_start = L - nphi // 2 |
100 | | - slice_stop = slice_start + nphi |
101 | | - ftm_slice = fm[slice_start:slice_stop] |
102 | | - |
103 | | - ftm_slice = ftm_slice.put_( |
104 | | - -torch.arange(1, L - nphi // 2 + 1) % nphi, |
105 | | - fm[slice_start - torch.arange(1, L - nphi // 2 + 1)], |
106 | | - accumulate=True, |
107 | | - ) |
108 | | - ftm_slice = ftm_slice.put_( |
109 | | - torch.arange(L - nphi // 2) % nphi, |
110 | | - fm[slice_stop + torch.arange(L - nphi // 2)], |
111 | | - accumulate=True, |
112 | | - ) |
113 | | - |
114 | | - return ftm_slice |
| 82 | +spectral_folding_torch = wrap_as_torch_function(spectral_folding_jax) |
115 | 83 |
|
116 | 84 |
|
117 | 85 | def spectral_periodic_extension(fm: np.ndarray, nphi: int, L: int) -> np.ndarray: |
@@ -174,29 +142,9 @@ def spectral_periodic_extension_jax(fm: jnp.ndarray, L: int) -> jnp.ndarray: |
174 | 142 | ) |
175 | 143 |
|
176 | 144 |
|
177 | | -def spectral_periodic_extension_torch(fm: torch.tensor, L: int) -> torch.tensor: |
178 | | - """ |
179 | | - Extends lower frequency Fourier coefficients onto higher frequency |
180 | | - coefficients, i.e. imposed periodicity in Fourier space. Based on |
181 | | - :func:`~spectral_periodic_extension`. |
182 | | -
|
183 | | - Args: |
184 | | - fm (torch.tensor): Slice of Fourier coefficients corresponding to ring at latitute t. |
185 | | -
|
186 | | - L (int): Harmonic band-limit. |
187 | | -
|
188 | | - Returns: |
189 | | - torch.tensor: Higher resolution set of periodic Fourier coefficients. |
190 | | -
|
191 | | - """ |
192 | | - nphi = fm.shape[0] |
193 | | - return torch.concatenate( |
194 | | - ( |
195 | | - fm[-torch.arange(L - nphi // 2, 0, -1) % nphi], |
196 | | - fm, |
197 | | - fm[torch.arange(L - (nphi + 1) // 2) % nphi], |
198 | | - ) |
199 | | - ) |
| 145 | +spectral_periodic_extension_torch = wrap_as_torch_function( |
| 146 | + spectral_periodic_extension_jax |
| 147 | +) |
200 | 148 |
|
201 | 149 |
|
202 | 150 | def healpix_fft( |
@@ -350,49 +298,7 @@ def f_chunks_to_ftm_rows(f_chunks, nphi): |
350 | 298 | ) |
351 | 299 |
|
352 | 300 |
|
353 | | -def healpix_fft_torch( |
354 | | - f: torch.tensor, L: int, nside: int, reality: bool |
355 | | -) -> torch.tensor: |
356 | | - """ |
357 | | - Computes the Forward Fast Fourier Transform with spectral back-projection |
358 | | - in the polar regions to manually enforce Fourier periodicity. Torch specific |
359 | | - implementation of :func:`~healpix_fft_numpy`. |
360 | | -
|
361 | | - Args: |
362 | | - f (torch.tensor): HEALPix pixel-space array. |
363 | | -
|
364 | | - L (int): Harmonic band-limit. |
365 | | -
|
366 | | - nside (int): HEALPix Nside resolution parameter. |
367 | | -
|
368 | | - reality (bool): Whether the signal on the sphere is real. If so, |
369 | | - conjugate symmetry is exploited to reduce computational costs. |
370 | | -
|
371 | | - Returns: |
372 | | - torch.tensor: Array of Fourier coefficients for all latitudes. |
373 | | -
|
374 | | - """ |
375 | | - index = 0 |
376 | | - ftm = torch.zeros(samples.ftm_shape(L, "healpix", nside), dtype=torch.complex128) |
377 | | - ntheta = ftm.shape[0] |
378 | | - for t in range(ntheta): |
379 | | - nphi = samples.nphi_ring(t, nside) |
380 | | - if reality and nphi == 2 * L: |
381 | | - fm_chunk = torch.zeros(nphi, dtype=torch.complex128) |
382 | | - fm_chunk[nphi // 2 :] = torch.fft.rfft( |
383 | | - torch.real(f[index : index + nphi]), norm="backward" |
384 | | - )[:-1] |
385 | | - else: |
386 | | - fm_chunk = torch.fft.fftshift( |
387 | | - torch.fft.fft(f[index : index + nphi], norm="backward") |
388 | | - ) |
389 | | - ftm[t] = ( |
390 | | - fm_chunk |
391 | | - if nphi == 2 * L |
392 | | - else spectral_periodic_extension_torch(fm_chunk, L) |
393 | | - ) |
394 | | - index += nphi |
395 | | - return ftm |
| 301 | +healpix_fft_torch = wrap_as_torch_function(healpix_fft_jax) |
396 | 302 |
|
397 | 303 |
|
398 | 304 | def healpix_ifft( |
@@ -531,46 +437,7 @@ def ftm_rows_to_f_chunks(ftm_rows, nphi): |
531 | 437 | ) |
532 | 438 |
|
533 | 439 |
|
534 | | -def healpix_ifft_torch( |
535 | | - ftm: torch.tensor, L: int, nside: int, reality: bool |
536 | | -) -> torch.tensor: |
537 | | - """ |
538 | | - Computes the Inverse Fast Fourier Transform with spectral folding in the polar |
539 | | - regions to mitigate aliasing. Torch specific implementation of |
540 | | - :func:`~healpix_ifft_numpy`. |
541 | | -
|
542 | | - Args: |
543 | | - ftm (torch.tensor): Array of Fourier coefficients for all latitudes. |
544 | | -
|
545 | | - L (int): Harmonic band-limit. |
546 | | -
|
547 | | - nside (int): HEALPix Nside resolution parameter. |
548 | | -
|
549 | | - reality (bool): Whether the signal on the sphere is real. If so, |
550 | | - conjugate symmetry is exploited to reduce computational costs. |
551 | | -
|
552 | | - Returns: |
553 | | - torch.tensor: HEALPix pixel-space array. |
554 | | -
|
555 | | - """ |
556 | | - f = torch.zeros( |
557 | | - samples.f_shape(sampling="healpix", nside=nside), dtype=torch.complex128 |
558 | | - ) |
559 | | - ntheta = ftm.shape[0] |
560 | | - index = 0 |
561 | | - for t in range(ntheta): |
562 | | - nphi = samples.nphi_ring(t, nside) |
563 | | - fm_chunk = ftm[t] if nphi == 2 * L else spectral_folding_torch(ftm[t], nphi, L) |
564 | | - if reality and nphi == 2 * L: |
565 | | - f[index : index + nphi] = torch.fft.irfft( |
566 | | - fm_chunk[nphi // 2 :], nphi, norm="forward" |
567 | | - ) |
568 | | - else: |
569 | | - f[index : index + nphi] = torch.fft.ifft( |
570 | | - torch.fft.ifftshift(fm_chunk), norm="forward" |
571 | | - ) |
572 | | - index += nphi |
573 | | - return f |
| 440 | +healpix_ifft_torch = wrap_as_torch_function(healpix_ifft_jax) |
574 | 441 |
|
575 | 442 |
|
576 | 443 | def p2phi_rings(t: np.ndarray, nside: int) -> np.ndarray: |
|
0 commit comments