@@ -336,14 +336,14 @@ def forward(
336336 f : np .ndarray ,
337337 L : int ,
338338 spin : int = 0 ,
339- nside : int = None ,
339+ nside : int | None = None ,
340340 sampling : str = "mw" ,
341341 method : str = "numpy" ,
342342 reality : bool = False ,
343- precomps : List = None ,
343+ precomps : List | None = None ,
344344 spmd : bool = False ,
345345 L_lower : int = 0 ,
346- iter : int = 3 ,
346+ iter : int | None = None ,
347347 _ssht_backend : int = 1 ,
348348) -> np .ndarray :
349349 r"""
@@ -379,9 +379,13 @@ def forward(
379379 L_lower (int, optional): Harmonic lower-bound. Transform will only be computed
380380 for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.
381381
382- iter (int, optional): Number of subiterations for healpy. Note that iterations
383- increase the precision of the forward transform, but reduce the accuracy of
384- the gradient pass. Between 2 and 3 iterations is a good compromise.
382+ iter (int, optional): Number of iterative refinement iterations to use to
383+ improve accuracy of forward transform (as an inverse of inverse transform).
384+ Primarily of use with HEALPix sampling for which there is not a sampling
385+ theorem, and round-tripping through the forward and inverse transforms will
386+ introduce an error. If set to `None`, the default, 3 iterations will be used
387+ if :code:`sampling == "healpix"` and :code`method == "jax_healpy"` and zero
388+ otherwise. Not used for `jax_ssht` method.
385389
386390 _ssht_backend (int, optional, experimental): Whether to default to SSHT core
387391 (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental
@@ -404,9 +408,10 @@ def forward(
404408 if spin >= 8 and method in ["numpy" , "jax" ]:
405409 raise Warning ("Recursive transform may provide lower precision beyond spin ~ 8" )
406410
411+ if iter is None :
412+ iter = 3 if sampling .lower () == "healpix" and method == "jax_healpy" else 0
407413 if method in {"numpy" , "jax" , "cuda" }:
408- kwargs = {
409- "f" : f ,
414+ common_kwargs = {
410415 "L" : L ,
411416 "spin" : spin ,
412417 "nside" : nside ,
@@ -416,10 +421,23 @@ def forward(
416421 "L_lower" : L_lower ,
417422 }
418423 if method in {"jax" , "cuda" }:
419- kwargs ["spmd" ] = spmd
420- kwargs ["use_healpix_custom_primitive" ] = method == "cuda"
421- forward_function = forward_numpy if method == "numpy" else forward_jax
422- return forward_function (** kwargs )
424+ forward_kwargs = {
425+ ** common_kwargs ,
426+ "spmd" : spmd ,
427+ "use_healpix_custom_primitive" : method == "cuda" ,
428+ }
429+ inverse_kwargs = {** common_kwargs , "method" : "jax" }
430+ forward_function = forward_jax
431+ else :
432+ forward_kwargs = common_kwargs
433+ inverse_kwargs = {** common_kwargs , "method" : "numpy" }
434+ forward_function = forward_numpy
435+ flm = forward_function (f , ** forward_kwargs )
436+ for _ in range (iter ):
437+ f_recov = inverse (flm , ** inverse_kwargs )
438+ f_error = f - f_recov
439+ flm += forward_function (f_error , ** forward_kwargs )
440+ return flm
423441 elif method == "jax_ssht" :
424442 if sampling .lower () == "healpix" :
425443 raise ValueError ("SSHT does not support healpix sampling." )
0 commit comments