Skip to content

Commit 6082ebb

Browse files
committed
Iterative refinement support for jax and numpy forward spherical
1 parent 27fb8b4 commit 6082ebb

File tree

1 file changed

+30
-12
lines changed

1 file changed

+30
-12
lines changed

s2fft/transforms/spherical.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -336,14 +336,14 @@ def forward(
336336
f: np.ndarray,
337337
L: int,
338338
spin: int = 0,
339-
nside: int = None,
339+
nside: int | None = None,
340340
sampling: str = "mw",
341341
method: str = "numpy",
342342
reality: bool = False,
343-
precomps: List = None,
343+
precomps: List | None = None,
344344
spmd: bool = False,
345345
L_lower: int = 0,
346-
iter: int = 3,
346+
iter: int | None = None,
347347
_ssht_backend: int = 1,
348348
) -> np.ndarray:
349349
r"""
@@ -379,9 +379,13 @@ def forward(
379379
L_lower (int, optional): Harmonic lower-bound. Transform will only be computed
380380
for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.
381381
382-
iter (int, optional): Number of subiterations for healpy. Note that iterations
383-
increase the precision of the forward transform, but reduce the accuracy of
384-
the gradient pass. Between 2 and 3 iterations is a good compromise.
382+
iter (int, optional): Number of iterative refinement iterations to use to
383+
improve accuracy of forward transform (as an inverse of inverse transform).
384+
Primarily of use with HEALPix sampling for which there is not a sampling
385+
theorem, and round-tripping through the forward and inverse transforms will
386+
introduce an error. If set to `None`, the default, 3 iterations will be used
387+
if :code:`sampling == "healpix"` and :code`method == "jax_healpy"` and zero
388+
otherwise. Not used for `jax_ssht` method.
385389
386390
_ssht_backend (int, optional, experimental): Whether to default to SSHT core
387391
(set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental
@@ -404,9 +408,10 @@ def forward(
404408
if spin >= 8 and method in ["numpy", "jax"]:
405409
raise Warning("Recursive transform may provide lower precision beyond spin ~ 8")
406410

411+
if iter is None:
412+
iter = 3 if sampling.lower() == "healpix" and method == "jax_healpy" else 0
407413
if method in {"numpy", "jax", "cuda"}:
408-
kwargs = {
409-
"f": f,
414+
common_kwargs = {
410415
"L": L,
411416
"spin": spin,
412417
"nside": nside,
@@ -416,10 +421,23 @@ def forward(
416421
"L_lower": L_lower,
417422
}
418423
if method in {"jax", "cuda"}:
419-
kwargs["spmd"] = spmd
420-
kwargs["use_healpix_custom_primitive"] = method == "cuda"
421-
forward_function = forward_numpy if method == "numpy" else forward_jax
422-
return forward_function(**kwargs)
424+
forward_kwargs = {
425+
**common_kwargs,
426+
"spmd": spmd,
427+
"use_healpix_custom_primitive": method == "cuda",
428+
}
429+
inverse_kwargs = {**common_kwargs, "method": "jax"}
430+
forward_function = forward_jax
431+
else:
432+
forward_kwargs = common_kwargs
433+
inverse_kwargs = {**common_kwargs, "method": "numpy"}
434+
forward_function = forward_numpy
435+
flm = forward_function(f, **forward_kwargs)
436+
for _ in range(iter):
437+
f_recov = inverse(flm, **inverse_kwargs)
438+
f_error = f - f_recov
439+
flm += forward_function(f_error, **forward_kwargs)
440+
return flm
423441
elif method == "jax_ssht":
424442
if sampling.lower() == "healpix":
425443
raise ValueError("SSHT does not support healpix sampling.")

0 commit comments

Comments
 (0)