11from functools import partial
2+ from typing import Optional
23from warnings import warn
34
45import jax .numpy as jnp
56import numpy as np
67import torch
78from jax import jit
89
10+ from s2fft .precompute_transforms import construct
911from s2fft .sampling import s2_samples as samples
1012from s2fft .utils import healpix_ffts as hp
11- from s2fft .utils import resampling , resampling_jax , resampling_torch
13+ from s2fft .utils import (
14+ iterative_refinement ,
15+ resampling ,
16+ resampling_jax ,
17+ resampling_torch ,
18+ )
1219
1320
1421def inverse (
1522 flm : np .ndarray ,
1623 L : int ,
1724 spin : int = 0 ,
18- kernel : np .ndarray = None ,
25+ kernel : Optional [ np .ndarray ] = None ,
1926 sampling : str = "mw" ,
2027 reality : bool = False ,
2128 method : str = "jax" ,
22- nside : int = None ,
29+ nside : Optional [ int ] = None ,
2330) -> np .ndarray :
2431 r"""
2532 Compute the inverse spherical harmonic transform via precompute.
@@ -55,21 +62,28 @@ def inverse(
5562 np.ndarray: Pixel-space coefficients with shape.
5663
5764 """
65+ if method not in _inverse_functions :
66+ raise ValueError (f"Method { method } not recognised." )
5867 if reality and spin != 0 :
5968 reality = False
6069 warn (
6170 "Reality acceleration only supports spin 0 fields. "
6271 + "Defering to complex transform." ,
6372 stacklevel = 2 ,
6473 )
65- if method == "numpy" :
66- return inverse_transform (flm , kernel , L , sampling , reality , spin , nside )
67- elif method == "jax" :
68- return inverse_transform_jax (flm , kernel , L , sampling , reality , spin , nside )
69- elif method == "torch" :
70- return inverse_transform_torch (flm , kernel , L , sampling , reality , spin , nside )
71- else :
72- raise ValueError (f"Method { method } not recognised." )
74+ common_kwargs = {
75+ "L" : L ,
76+ "sampling" : sampling ,
77+ "reality" : reality ,
78+ "spin" : spin ,
79+ "nside" : nside ,
80+ }
81+ kernel = (
82+ _kernel_functions [method ](forward = False , ** common_kwargs )
83+ if kernel is None
84+ else kernel
85+ )
86+ return _inverse_functions [method ](flm , kernel , ** common_kwargs )
7387
7488
7589def inverse_transform (
@@ -290,11 +304,12 @@ def forward(
290304 f : np .ndarray ,
291305 L : int ,
292306 spin : int = 0 ,
293- kernel : np .ndarray = None ,
307+ kernel : Optional [ np .ndarray ] = None ,
294308 sampling : str = "mw" ,
295309 reality : bool = False ,
296310 method : str = "jax" ,
297- nside : int = None ,
311+ nside : Optional [int ] = None ,
312+ iter : int = 0 ,
298313) -> np .ndarray :
299314 r"""
300315 Compute the forward spherical harmonic transform via precompute.
@@ -321,6 +336,12 @@ def forward(
321336 nside (int): HEALPix Nside resolution parameter. Only required
322337 if sampling="healpix".
323338
339+ iter (int, optional): Number of iterative refinement iterations to use to
340+ improve accuracy of forward transform (as an inverse of inverse transform).
341+ Primarily of use with HEALPix sampling for which there is not a sampling
342+ theorem, and round-tripping through the forward and inverse transforms will
343+ introduce an error.
344+
324345 Raises:
325346 ValueError: Transform method not recognised.
326347
@@ -330,21 +351,41 @@ def forward(
330351 np.ndarray: Spherical harmonic coefficients.
331352
332353 """
354+ if method not in _forward_functions :
355+ raise ValueError (f"Method { method } not recognised." )
333356 if reality and spin != 0 :
334357 reality = False
335358 warn (
336359 "Reality acceleration only supports spin 0 fields. "
337360 + "Defering to complex transform." ,
338361 stacklevel = 2 ,
339362 )
340- if method == "numpy" :
341- return forward_transform (f , kernel , L , sampling , reality , spin , nside )
342- elif method == "jax" :
343- return forward_transform_jax (f , kernel , L , sampling , reality , spin , nside )
344- elif method == "torch" :
345- return forward_transform_torch (f , kernel , L , sampling , reality , spin , nside )
363+ common_kwargs = {
364+ "L" : L ,
365+ "sampling" : sampling ,
366+ "reality" : reality ,
367+ "spin" : spin ,
368+ "nside" : nside ,
369+ }
370+ kernel = (
371+ _kernel_functions [method ](forward = True , ** common_kwargs )
372+ if kernel is None
373+ else kernel
374+ )
375+ if iter == 0 :
376+ return _forward_functions [method ](f , kernel , ** common_kwargs )
346377 else :
347- raise ValueError (f"Method { method } not recognised." )
378+ inverse_kernel = _kernel_functions [method ](forward = False , ** common_kwargs )
379+ return iterative_refinement .forward_with_iterative_refinement (
380+ f = f ,
381+ n_iter = iter ,
382+ forward_function = partial (
383+ _forward_functions [method ], kernel = kernel , ** common_kwargs
384+ ),
385+ backward_function = partial (
386+ _inverse_functions [method ], kernel = inverse_kernel , ** common_kwargs
387+ ),
388+ )
348389
349390
350391def forward_transform (
@@ -567,3 +608,23 @@ def forward_transform_torch(
567608 )
568609
569610 return flm * (- 1 ) ** spin
611+
612+
613+ _inverse_functions = {
614+ "numpy" : inverse_transform ,
615+ "jax" : inverse_transform_jax ,
616+ "torch" : inverse_transform_torch ,
617+ }
618+
619+
620+ _forward_functions = {
621+ "numpy" : forward_transform ,
622+ "jax" : forward_transform_jax ,
623+ "torch" : forward_transform_torch ,
624+ }
625+
626+ _kernel_functions = {
627+ "numpy" : partial (construct .spin_spherical_kernel , using_torch = False ),
628+ "jax" : construct .spin_spherical_kernel_jax ,
629+ "torch" : partial (construct .spin_spherical_kernel , using_torch = True ),
630+ }
0 commit comments