Skip to content

Commit a9e7c0c

Browse files
authored
Merge pull request #241 from astro-informatics/mmg/iterative-refinement
Iterative refinement support for JAX and NumPy forward (spherical) transform implementations
2 parents 8f6e4d5 + 23fb7af commit a9e7c0c

File tree

11 files changed

+273
-76
lines changed

11 files changed

+273
-76
lines changed

s2fft/base_transforms/spherical.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
from functools import partial
12
from warnings import warn
23

34
import numpy as np
45

56
from s2fft import recursions
67
from s2fft.sampling import s2_samples as samples
78
from s2fft.utils import healpix_ffts as hp
8-
from s2fft.utils import quadrature, resampling
9+
from s2fft.utils import iterative_refinement, quadrature, resampling
910

1011

1112
def inverse(
@@ -138,6 +139,7 @@ def forward(
138139
nside: int = None,
139140
reality: bool = False,
140141
L_lower: int = 0,
142+
iter: int = 0,
141143
) -> np.ndarray:
142144
r"""
143145
Compute forward spherical harmonic transform.
@@ -164,20 +166,34 @@ def forward(
164166
L_lower (int, optional): Harmonic lower-bound. Transform will only be computed
165167
for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.
166168
169+
iter (int, optional): Number of iterative refinement iterations to use to
170+
improve accuracy of forward transform (as an inverse of inverse transform).
171+
Primarily of use with HEALPix sampling for which there is not a sampling
172+
theorem, and round-tripping through the forward and inverse transforms will
173+
introduce an error.
174+
167175
Returns:
168176
np.ndarray: Spherical harmonic coefficients.
169177
170178
"""
171-
return _forward(
172-
f,
173-
L,
174-
spin,
175-
sampling,
176-
nside=nside,
177-
method="sov_fft_vectorized",
178-
reality=reality,
179-
L_lower=L_lower,
180-
)
179+
common_kwargs = {
180+
"L": L,
181+
"spin": spin,
182+
"sampling": sampling,
183+
"nside": nside,
184+
"method": "sov_fft_vectorized",
185+
"reality": reality,
186+
"L_lower": L_lower,
187+
}
188+
if iter == 0:
189+
return _forward(f, **common_kwargs)
190+
else:
191+
return iterative_refinement.forward_with_iterative_refinement(
192+
f,
193+
n_iter=iter,
194+
forward_function=partial(_forward, **common_kwargs),
195+
backward_function=partial(_inverse, **common_kwargs),
196+
)
181197

182198

183199
def _forward(

s2fft/precompute_transforms/spherical.py

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
11
from functools import partial
2+
from typing import Optional
23
from warnings import warn
34

45
import jax.numpy as jnp
56
import numpy as np
67
import torch
78
from jax import jit
89

10+
from s2fft.precompute_transforms import construct
911
from s2fft.sampling import s2_samples as samples
1012
from s2fft.utils import healpix_ffts as hp
11-
from s2fft.utils import resampling, resampling_jax, resampling_torch
13+
from s2fft.utils import (
14+
iterative_refinement,
15+
resampling,
16+
resampling_jax,
17+
resampling_torch,
18+
)
1219

1320

1421
def inverse(
1522
flm: np.ndarray,
1623
L: int,
1724
spin: int = 0,
18-
kernel: np.ndarray = None,
25+
kernel: Optional[np.ndarray] = None,
1926
sampling: str = "mw",
2027
reality: bool = False,
2128
method: str = "jax",
22-
nside: int = None,
29+
nside: Optional[int] = None,
2330
) -> np.ndarray:
2431
r"""
2532
Compute the inverse spherical harmonic transform via precompute.
@@ -55,21 +62,28 @@ def inverse(
5562
np.ndarray: Pixel-space coefficients with shape.
5663
5764
"""
65+
if method not in _inverse_functions:
66+
raise ValueError(f"Method {method} not recognised.")
5867
if reality and spin != 0:
5968
reality = False
6069
warn(
6170
"Reality acceleration only supports spin 0 fields. "
6271
+ "Defering to complex transform.",
6372
stacklevel=2,
6473
)
65-
if method == "numpy":
66-
return inverse_transform(flm, kernel, L, sampling, reality, spin, nside)
67-
elif method == "jax":
68-
return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside)
69-
elif method == "torch":
70-
return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside)
71-
else:
72-
raise ValueError(f"Method {method} not recognised.")
74+
common_kwargs = {
75+
"L": L,
76+
"sampling": sampling,
77+
"reality": reality,
78+
"spin": spin,
79+
"nside": nside,
80+
}
81+
kernel = (
82+
_kernel_functions[method](forward=False, **common_kwargs)
83+
if kernel is None
84+
else kernel
85+
)
86+
return _inverse_functions[method](flm, kernel, **common_kwargs)
7387

7488

7589
def inverse_transform(
@@ -290,11 +304,12 @@ def forward(
290304
f: np.ndarray,
291305
L: int,
292306
spin: int = 0,
293-
kernel: np.ndarray = None,
307+
kernel: Optional[np.ndarray] = None,
294308
sampling: str = "mw",
295309
reality: bool = False,
296310
method: str = "jax",
297-
nside: int = None,
311+
nside: Optional[int] = None,
312+
iter: int = 0,
298313
) -> np.ndarray:
299314
r"""
300315
Compute the forward spherical harmonic transform via precompute.
@@ -321,6 +336,12 @@ def forward(
321336
nside (int): HEALPix Nside resolution parameter. Only required
322337
if sampling="healpix".
323338
339+
iter (int, optional): Number of iterative refinement iterations to use to
340+
improve accuracy of forward transform (as an inverse of inverse transform).
341+
Primarily of use with HEALPix sampling for which there is not a sampling
342+
theorem, and round-tripping through the forward and inverse transforms will
343+
introduce an error.
344+
324345
Raises:
325346
ValueError: Transform method not recognised.
326347
@@ -330,21 +351,41 @@ def forward(
330351
np.ndarray: Spherical harmonic coefficients.
331352
332353
"""
354+
if method not in _forward_functions:
355+
raise ValueError(f"Method {method} not recognised.")
333356
if reality and spin != 0:
334357
reality = False
335358
warn(
336359
"Reality acceleration only supports spin 0 fields. "
337360
+ "Defering to complex transform.",
338361
stacklevel=2,
339362
)
340-
if method == "numpy":
341-
return forward_transform(f, kernel, L, sampling, reality, spin, nside)
342-
elif method == "jax":
343-
return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside)
344-
elif method == "torch":
345-
return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside)
363+
common_kwargs = {
364+
"L": L,
365+
"sampling": sampling,
366+
"reality": reality,
367+
"spin": spin,
368+
"nside": nside,
369+
}
370+
kernel = (
371+
_kernel_functions[method](forward=True, **common_kwargs)
372+
if kernel is None
373+
else kernel
374+
)
375+
if iter == 0:
376+
return _forward_functions[method](f, kernel, **common_kwargs)
346377
else:
347-
raise ValueError(f"Method {method} not recognised.")
378+
inverse_kernel = _kernel_functions[method](forward=False, **common_kwargs)
379+
return iterative_refinement.forward_with_iterative_refinement(
380+
f=f,
381+
n_iter=iter,
382+
forward_function=partial(
383+
_forward_functions[method], kernel=kernel, **common_kwargs
384+
),
385+
backward_function=partial(
386+
_inverse_functions[method], kernel=inverse_kernel, **common_kwargs
387+
),
388+
)
348389

349390

350391
def forward_transform(
@@ -567,3 +608,23 @@ def forward_transform_torch(
567608
)
568609

569610
return flm * (-1) ** spin
611+
612+
613+
_inverse_functions = {
614+
"numpy": inverse_transform,
615+
"jax": inverse_transform_jax,
616+
"torch": inverse_transform_torch,
617+
}
618+
619+
620+
_forward_functions = {
621+
"numpy": forward_transform,
622+
"jax": forward_transform_jax,
623+
"torch": forward_transform_torch,
624+
}
625+
626+
_kernel_functions = {
627+
"numpy": partial(construct.spin_spherical_kernel, using_torch=False),
628+
"jax": construct.spin_spherical_kernel_jax,
629+
"torch": partial(construct.spin_spherical_kernel, using_torch=True),
630+
}

s2fft/transforms/c_backend_spherical.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import healpy
24
import jax.numpy as jnp
35
import numpy as np
@@ -8,7 +10,7 @@
810
from jax.interpreters import ad
911

1012
from s2fft.sampling import reindex
11-
from s2fft.utils import quadrature_jax
13+
from s2fft.utils import iterative_refinement, quadrature_jax
1214

1315

1416
@custom_vjp
@@ -427,11 +429,12 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
427429
Astrophysical Journal 622.2 (2005): 759
428430
429431
"""
430-
flm = healpy_map2alm(f, L, nside)
431-
for _ in range(iter):
432-
f_recov = healpy_alm2map(flm, L, nside)
433-
f_error = f - f_recov
434-
flm += healpy_map2alm(f_error, L, nside)
432+
flm = iterative_refinement.forward_with_iterative_refinement(
433+
f=f,
434+
n_iter=iter,
435+
forward_function=partial(healpy_map2alm, L=L, nside=nside),
436+
backward_function=partial(healpy_alm2map, L=L, nside=nside),
437+
)
435438
return reindex.flm_hp_to_2d_fast(flm, L)
436439

437440

s2fft/transforms/otf_recursions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def inverse_latitudinal_step(
8383
precomps = generate_precomputes(L, -mm, sampling, nside, L_lower)
8484
lrenorm, vsign, cpi, cp2, indices = precomps
8585

86+
# Create copy to prevent in-place updates propagating to caller
87+
lrenorm = lrenorm.copy()
88+
8689
for i in range(2):
8790
if not (reality and i == 0):
8891
m_offset = 1 if sampling in ["mwss", "healpix"] and i == 0 else 0
@@ -490,6 +493,9 @@ def forward_latitudinal_step(
490493
precomps = generate_precomputes(L, -mm, sampling, nside, True, L_lower)
491494
lrenorm, vsign, cpi, cp2, indices = precomps
492495

496+
# Create copy to prevent in-place updates propagating to caller
497+
lrenorm = lrenorm.copy()
498+
493499
for i in range(2):
494500
if not (reality and i == 0):
495501
m_offset = 1 if sampling in ["mwss", "healpix"] and i == 0 else 0

0 commit comments

Comments
 (0)