11from functools import partial
2+ from typing import Optional
23from warnings import warn
34
45import jax .numpy as jnp
56import numpy as np
67import torch
78from jax import jit
89
10+ from s2fft .precompute_transforms import construct
911from s2fft .sampling import s2_samples as samples
1012from s2fft .utils import healpix_ffts as hp
11- from s2fft .utils import resampling , resampling_jax , resampling_torch
13+ from s2fft .utils import (
14+ iterative_refinement ,
15+ resampling ,
16+ resampling_jax ,
17+ resampling_torch ,
18+ )
1219
1320
1421def inverse (
1522 flm : np .ndarray ,
1623 L : int ,
1724 spin : int = 0 ,
18- kernel : np .ndarray = None ,
25+ kernel : Optional [ np .ndarray ] = None ,
1926 sampling : str = "mw" ,
2027 reality : bool = False ,
2128 method : str = "jax" ,
22- nside : int = None ,
29+ nside : Optional [ int ] = None ,
2330) -> np .ndarray :
2431 r"""
2532 Compute the inverse spherical harmonic transform via precompute.
@@ -62,14 +69,21 @@ def inverse(
6269 + "Defering to complex transform." ,
6370 stacklevel = 2 ,
6471 )
65- inverse_functions = {
66- "numpy" : inverse_transform ,
67- "jax" : inverse_transform_jax ,
68- "torch" : inverse_transform_torch ,
72+ common_kwargs = {
73+ "L" : L ,
74+ "sampling" : sampling ,
75+ "reality" : reality ,
76+ "spin" : spin ,
77+ "nside" : nside ,
6978 }
70- if method not in inverse_functions :
79+ kernel = (
80+ _kernel_functions [method ](forward = False , ** common_kwargs )
81+ if kernel is None
82+ else kernel
83+ )
84+ if method not in _inverse_functions :
7185 raise ValueError (f"Method { method } not recognised." )
72- return inverse_functions [method ](flm , kernel , L , sampling , reality , spin , nside )
86+ return _inverse_functions [method ](flm , kernel , ** common_kwargs )
7387
7488
7589def inverse_transform (
@@ -290,11 +304,12 @@ def forward(
290304 f : np .ndarray ,
291305 L : int ,
292306 spin : int = 0 ,
293- kernel : np .ndarray = None ,
307+ kernel : Optional [ np .ndarray ] = None ,
294308 sampling : str = "mw" ,
295309 reality : bool = False ,
296310 method : str = "jax" ,
297- nside : int = None ,
311+ nside : Optional [int ] = None ,
312+ iter : int = 0 ,
298313) -> np .ndarray :
299314 r"""
300315 Compute the forward spherical harmonic transform via precompute.
@@ -321,6 +336,12 @@ def forward(
321336 nside (int): HEALPix Nside resolution parameter. Only required
322337 if sampling="healpix".
323338
339+ iter (int, optional): Number of iterative refinement iterations to use to
340+ improve accuracy of forward transform (as an inverse of inverse transform).
341+ Primarily of use with HEALPix sampling for which there is not a sampling
342+ theorem, and round-tripping through the forward and inverse transforms will
343+ introduce an error.
344+
324345 Raises:
325346 ValueError: Transform method not recognised.
326347
@@ -337,14 +358,34 @@ def forward(
337358 + "Defering to complex transform." ,
338359 stacklevel = 2 ,
339360 )
340- forward_functions = {
341- "numpy" : forward_transform ,
342- "jax" : forward_transform_jax ,
343- "torch" : forward_transform_torch ,
361+ common_kwargs = {
362+ "L" : L ,
363+ "sampling" : sampling ,
364+ "reality" : reality ,
365+ "spin" : spin ,
366+ "nside" : nside ,
344367 }
345- if method not in forward_functions :
368+ kernel = (
369+ _kernel_functions [method ](forward = True , ** common_kwargs )
370+ if kernel is None
371+ else kernel
372+ )
373+ if method not in _forward_functions :
346374 raise ValueError (f"Method { method } not recognised." )
347- return forward_functions [method ](f , kernel , L , sampling , reality , spin , nside )
375+ if iter == 0 :
376+ return _forward_functions [method ](f , kernel , ** common_kwargs )
377+ else :
378+ inverse_kernel = _kernel_functions [method ](forward = False , ** common_kwargs )
379+ return iterative_refinement .forward_with_iterative_refinement (
380+ f = f ,
381+ n_iter = iter ,
382+ forward_function = partial (
383+ _forward_functions [method ], kernel = kernel , ** common_kwargs
384+ ),
385+ backward_function = partial (
386+ _inverse_functions [method ], kernel = inverse_kernel , ** common_kwargs
387+ ),
388+ )
348389
349390
350391def forward_transform (
@@ -567,3 +608,23 @@ def forward_transform_torch(
567608 )
568609
569610 return flm * (- 1 ) ** spin
611+
612+
613+ _inverse_functions = {
614+ "numpy" : inverse_transform ,
615+ "jax" : inverse_transform_jax ,
616+ "torch" : inverse_transform_torch ,
617+ }
618+
619+
620+ _forward_functions = {
621+ "numpy" : forward_transform ,
622+ "jax" : forward_transform_jax ,
623+ "torch" : forward_transform_torch ,
624+ }
625+
626+ _kernel_functions = {
627+ "numpy" : partial (construct .fourier_wigner_kernel , using_torch = False ),
628+ "jax" : construct .fourier_wigner_kernel_jax ,
629+ "torch" : partial (construct .fourier_wigner_kernel , using_torch = True ),
630+ }
0 commit comments