Skip to content

Commit 034ecbd

Browse files
committed
Add torch wrappers for Wigner OTF transforms
1 parent cd3988d commit 034ecbd

File tree

2 files changed

+27
-23
lines changed

2 files changed

+27
-23
lines changed

s2fft/transforms/wigner.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import s2fft
99
from s2fft.sampling import so3_samples as samples
1010
from s2fft.transforms import c_backend_spherical as c_sph
11+
from s2fft.utils import torch_wrapper
1112

1213

1314
def inverse(
@@ -84,7 +85,7 @@ def inverse(
8485
if method not in _inverse_functions:
8586
raise ValueError(f"Method {method} not recognised.")
8687

87-
if N >= 8 and method in ("numpy", "jax"):
88+
if N >= 8 and method in ("numpy", "jax", "torch"):
8889
raise Warning("Recursive transform may provide lower precision beyond N ~ 8")
8990

9091
inverse_kwargs = {
@@ -96,7 +97,7 @@ def inverse(
9697
"reality": reality,
9798
}
9899

99-
if method in ("jax", "numpy"):
100+
if method in ("jax", "numpy", "torch"):
100101
inverse_kwargs.update(nside=nside, precomps=precomps)
101102

102103
if method == "jax_ssht":
@@ -288,6 +289,9 @@ def func(flm, spin, p0, p1, p2, p3, p4):
288289
return f
289290

290291

292+
inverse_torch = torch_wrapper.wrap_as_torch_function(inverse_jax)
293+
294+
291295
def inverse_jax_ssht(
292296
flmn: jnp.ndarray,
293297
L: int,
@@ -413,7 +417,7 @@ def forward(
413417
if method not in _inverse_functions:
414418
raise ValueError(f"Method {method} not recognised.")
415419

416-
if N >= 8 and method in ("numpy", "jax"):
420+
if N >= 8 and method in ("numpy", "jax", "torch"):
417421
raise Warning("Recursive transform may provide lower precision beyond N ~ 8")
418422

419423
forward_kwargs = {
@@ -425,7 +429,7 @@ def forward(
425429
"reality": reality,
426430
}
427431

428-
if method in ("jax", "numpy"):
432+
if method in ("jax", "numpy", "torch"):
429433
forward_kwargs.update(nside=nside, precomps=precomps)
430434

431435
if method == "jax_ssht":
@@ -642,6 +646,9 @@ def func(fba, spin, p0, p1, p2, p3, p4):
642646
return flmn
643647

644648

649+
forward_torch = torch_wrapper.wrap_as_torch_function(forward_jax)
650+
651+
645652
def forward_jax_ssht(
646653
f: jnp.ndarray,
647654
L: int,
@@ -829,10 +836,12 @@ def _fban_to_f(fban: jnp.ndarray, L: int, N: int, reality: bool = False) -> jnp.
829836
"numpy": inverse_numpy,
830837
"jax": inverse_jax,
831838
"jax_ssht": inverse_jax_ssht,
839+
"torch": inverse_torch,
832840
}
833841

834842
_forward_functions = {
835843
"numpy": forward_numpy,
836844
"jax": forward_jax,
837845
"jax_ssht": forward_jax_ssht,
846+
"torch": forward_torch,
838847
}

tests/test_wigner_transform.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,16 @@
1616
N_to_test = [2]
1717
L_lower_to_test = [0, 2]
1818
sampling_to_test = ["mw", "mwss", "dh", "gl"]
19-
method_to_test = ["numpy", "jax"]
19+
method_to_test = ["numpy", "jax", "torch"]
2020
reality_to_test = [False, True]
2121

22+
_generate_precomputes_functions = {
23+
"jax": generate_precomputes_wigner_jax,
24+
"numpy": generate_precomputes_wigner,
25+
# torch method wraps jax so use jax to generate precomputess
26+
"torch": generate_precomputes_wigner_jax,
27+
}
28+
2229

2330
@pytest.mark.parametrize("L", L_to_test)
2431
@pytest.mark.parametrize("N", N_to_test)
@@ -38,15 +45,9 @@ def test_inverse_wigner_transform(
3845
):
3946
flmn = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality)
4047
f_check = base_wigner.inverse(flmn, L, N, L_lower, sampling, reality)
41-
42-
if method.lower() == "jax":
43-
precomps = generate_precomputes_wigner_jax(
44-
L, N, sampling, None, False, reality, L_lower
45-
)
46-
else:
47-
precomps = generate_precomputes_wigner(
48-
L, N, sampling, None, False, reality, L_lower
49-
)
48+
precomps = _generate_precomputes_functions[method](
49+
L, N, sampling, None, False, reality, L_lower
50+
)
5051
f = wigner.inverse(flmn, L, N, None, sampling, method, reality, precomps, L_lower)
5152
np.testing.assert_allclose(f, f_check, atol=1e-14)
5253

@@ -69,15 +70,9 @@ def test_forward_wigner_transform(
6970
):
7071
flmn = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality)
7172
f = base_wigner.inverse(flmn, L, N, L_lower, sampling, reality)
72-
73-
if method.lower() == "jax":
74-
precomps = generate_precomputes_wigner_jax(
75-
L, N, sampling, None, True, reality, L_lower
76-
)
77-
else:
78-
precomps = generate_precomputes_wigner(
79-
L, N, sampling, None, True, reality, L_lower
80-
)
73+
precomps = _generate_precomputes_functions[method](
74+
L, N, sampling, None, True, reality, L_lower
75+
)
8176
flmn_check = wigner.forward(
8277
f, L, N, None, sampling, method, reality, precomps, L_lower
8378
)

0 commit comments

Comments
 (0)