Skip to content

Commit 8dbcc97

Browse files
committed
Add iterative refinement option to precompute forward transform
1 parent d5fd078 commit 8dbcc97

File tree

1 file changed

+78
-17
lines changed

1 file changed

+78
-17
lines changed

s2fft/precompute_transforms/spherical.py

Lines changed: 78 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
11
from functools import partial
2+
from typing import Optional
23
from warnings import warn
34

45
import jax.numpy as jnp
56
import numpy as np
67
import torch
78
from jax import jit
89

10+
from s2fft.precompute_transforms import construct
911
from s2fft.sampling import s2_samples as samples
1012
from s2fft.utils import healpix_ffts as hp
11-
from s2fft.utils import resampling, resampling_jax, resampling_torch
13+
from s2fft.utils import (
14+
iterative_refinement,
15+
resampling,
16+
resampling_jax,
17+
resampling_torch,
18+
)
1219

1320

1421
def inverse(
1522
flm: np.ndarray,
1623
L: int,
1724
spin: int = 0,
18-
kernel: np.ndarray = None,
25+
kernel: Optional[np.ndarray] = None,
1926
sampling: str = "mw",
2027
reality: bool = False,
2128
method: str = "jax",
22-
nside: int = None,
29+
nside: Optional[int] = None,
2330
) -> np.ndarray:
2431
r"""
2532
Compute the inverse spherical harmonic transform via precompute.
@@ -62,14 +69,21 @@ def inverse(
6269
+ "Defering to complex transform.",
6370
stacklevel=2,
6471
)
65-
inverse_functions = {
66-
"numpy": inverse_transform,
67-
"jax": inverse_transform_jax,
68-
"torch": inverse_transform_torch,
72+
common_kwargs = {
73+
"L": L,
74+
"sampling": sampling,
75+
"reality": reality,
76+
"spin": spin,
77+
"nside": nside,
6978
}
70-
if method not in inverse_functions:
79+
kernel = (
80+
_kernel_functions[method](forward=False, **common_kwargs)
81+
if kernel is None
82+
else kernel
83+
)
84+
if method not in _inverse_functions:
7185
raise ValueError(f"Method {method} not recognised.")
72-
return inverse_functions[method](flm, kernel, L, sampling, reality, spin, nside)
86+
return _inverse_functions[method](flm, kernel, **common_kwargs)
7387

7488

7589
def inverse_transform(
@@ -290,11 +304,12 @@ def forward(
290304
f: np.ndarray,
291305
L: int,
292306
spin: int = 0,
293-
kernel: np.ndarray = None,
307+
kernel: Optional[np.ndarray] = None,
294308
sampling: str = "mw",
295309
reality: bool = False,
296310
method: str = "jax",
297-
nside: int = None,
311+
nside: Optional[int] = None,
312+
iter: int = 0,
298313
) -> np.ndarray:
299314
r"""
300315
Compute the forward spherical harmonic transform via precompute.
@@ -321,6 +336,12 @@ def forward(
321336
nside (int): HEALPix Nside resolution parameter. Only required
322337
if sampling="healpix".
323338
339+
iter (int, optional): Number of iterative refinement iterations to use to
340+
improve accuracy of forward transform (as an inverse of inverse transform).
341+
Primarily of use with HEALPix sampling for which there is not a sampling
342+
theorem, and round-tripping through the forward and inverse transforms will
343+
introduce an error.
344+
324345
Raises:
325346
ValueError: Transform method not recognised.
326347
@@ -337,14 +358,34 @@ def forward(
337358
+ "Defering to complex transform.",
338359
stacklevel=2,
339360
)
340-
forward_functions = {
341-
"numpy": forward_transform,
342-
"jax": forward_transform_jax,
343-
"torch": forward_transform_torch,
361+
common_kwargs = {
362+
"L": L,
363+
"sampling": sampling,
364+
"reality": reality,
365+
"spin": spin,
366+
"nside": nside,
344367
}
345-
if method not in forward_functions:
368+
kernel = (
369+
_kernel_functions[method](forward=True, **common_kwargs)
370+
if kernel is None
371+
else kernel
372+
)
373+
if method not in _forward_functions:
346374
raise ValueError(f"Method {method} not recognised.")
347-
return forward_functions[method](f, kernel, L, sampling, reality, spin, nside)
375+
if iter == 0:
376+
return _forward_functions[method](f, kernel, **common_kwargs)
377+
else:
378+
inverse_kernel = _kernel_functions[method](forward=False, **common_kwargs)
379+
return iterative_refinement.forward_with_iterative_refinement(
380+
f=f,
381+
n_iter=iter,
382+
forward_function=partial(
383+
_forward_functions[method], kernel=kernel, **common_kwargs
384+
),
385+
backward_function=partial(
386+
_inverse_functions[method], kernel=inverse_kernel, **common_kwargs
387+
),
388+
)
348389

349390

350391
def forward_transform(
@@ -567,3 +608,23 @@ def forward_transform_torch(
567608
)
568609

569610
return flm * (-1) ** spin
611+
612+
613+
_inverse_functions = {
614+
"numpy": inverse_transform,
615+
"jax": inverse_transform_jax,
616+
"torch": inverse_transform_torch,
617+
}
618+
619+
620+
_forward_functions = {
621+
"numpy": forward_transform,
622+
"jax": forward_transform_jax,
623+
"torch": forward_transform_torch,
624+
}
625+
626+
_kernel_functions = {
627+
"numpy": partial(construct.fourier_wigner_kernel, using_torch=False),
628+
"jax": construct.fourier_wigner_kernel_jax,
629+
"torch": partial(construct.fourier_wigner_kernel, using_torch=True),
630+
}

0 commit comments

Comments
 (0)