Skip to content

Commit 85645ae

Browse files
committed
Factor out iterative refinement function
1 parent 5c47269 commit 85645ae

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

s2fft/transforms/spherical.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from s2fft.transforms import otf_recursions as otf
1111
from s2fft.utils import healpix_ffts as hp
1212
from s2fft.utils import (
13+
iterative_refinement,
1314
quadrature,
1415
quadrature_jax,
1516
resampling,
@@ -435,12 +436,12 @@ def forward(
435436
forward_kwargs = common_kwargs
436437
inverse_kwargs = {**common_kwargs, "method": "numpy"}
437438
forward_function = forward_numpy
438-
flm = forward_function(f, **forward_kwargs)
439-
for _ in range(iter):
440-
f_recov = inverse(flm, **inverse_kwargs)
441-
f_error = f - f_recov
442-
flm += forward_function(f_error, **forward_kwargs)
443-
return flm
439+
return iterative_refinement.forward_with_iterative_refinement(
440+
f=f,
441+
n_iter=iter,
442+
forward_function=partial(forward_function, **forward_kwargs),
443+
backward_function=partial(inverse, **inverse_kwargs),
444+
)
444445
elif method == "jax_ssht":
445446
if sampling.lower() == "healpix":
446447
raise ValueError("SSHT does not support healpix sampling.")

s2fft/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from . import (
22
healpix_ffts,
3+
iterative_refinement,
34
jax_primitive,
45
quadrature,
56
quadrature_jax,
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Iterative scheme for improving accuracy of linear transforms."""
2+
3+
from collections.abc import Callable
4+
from typing import TypeVar
5+
6+
T = TypeVar("T")
7+
8+
9+
def forward_with_iterative_refinement(
10+
f: T,
11+
n_iter: int,
12+
forward_function: Callable[[T], T],
13+
backward_function: Callable[[T], T],
14+
) -> T:
15+
"""
16+
Apply forward transform with iterative refinement to improve accuracy.
17+
18+
`Iterative refinement <https://en.wikipedia.org/wiki/Iterative_refinement>`_ is a
19+
general approach for improving the accuracy of numerial solutions to linear systems.
20+
In the context of spherical harmonic transforms, given a forward transform which is
21+
an _approximate_ inverse to a corresponding backward ('inverse') transform,
22+
iterative refinement allows defining an iterative forward transform which is a more
23+
accurate
24+
25+
Args:
26+
f: Array argument to forward transform (signal on sphere) to compute iteratively
27+
refined forward transform at.
28+
29+
n_iter: Number of refinement iterations to use, non-negative.
30+
31+
forward_function: Function computing forward transform (approximate inverse of
32+
backward transform).
33+
34+
backward_function: Function computing backward ('inverse') transform.
35+
36+
Returns:
37+
Array output from iteratively refined forward transform (spherical harmonic
38+
coefficients).
39+
40+
"""
41+
flm = forward_function(f)
42+
for _ in range(n_iter):
43+
f_recov = backward_function(flm)
44+
f_error = f - f_recov
45+
flm += forward_function(f_error)
46+
return flm

0 commit comments

Comments
 (0)