Skip to content

Commit 972e647

Browse files
committed
Use factored out function in healpy wrapper
1 parent 85645ae commit 972e647

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

s2fft/transforms/c_backend_spherical.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import healpy
24
import jax.numpy as jnp
35
import numpy as np
@@ -8,7 +10,7 @@
810
from jax.interpreters import ad
911

1012
from s2fft.sampling import reindex
11-
from s2fft.utils import quadrature_jax
13+
from s2fft.utils import iterative_refinement, quadrature_jax
1214

1315

1416
@custom_vjp
@@ -427,11 +429,12 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
427429
Astrophysical Journal 622.2 (2005): 759
428430
429431
"""
430-
flm = healpy_map2alm(f, L, nside)
431-
for _ in range(iter):
432-
f_recov = healpy_alm2map(flm, L, nside)
433-
f_error = f - f_recov
434-
flm += healpy_map2alm(f_error, L, nside)
432+
flm = iterative_refinement.forward_with_iterative_refinement(
433+
f=f,
434+
n_iter=iter,
435+
forward_function=partial(healpy_map2alm, L=L, nside=nside),
436+
backward_function=partial(healpy_alm2map, L=L, nside=nside),
437+
)
435438
return reindex.flm_hp_to_2d_fast(flm, L)
436439

437440

0 commit comments

Comments
 (0)