Skip to content

Commit 69b4a05

Browse files
committed
explicitly add jax/torch quadrature tests for codecov
1 parent a813bf3 commit 69b4a05

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

s2fft/utils/quadrature_torch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def quad_weights_mwss_theta_only(L: int) -> torch.tensor:
195195

196196
wr = torch.real(torch.fft.fft(torch.fft.ifftshift(w), norm="backward")) / (2 * L)
197197
q = wr[: L + 1]
198-
q[1:L] += wr[-1:L:-1]
198+
q[1:L] += torch.flip(wr, dims=[0])[: L - 1]
199199

200200
return q
201201

@@ -221,7 +221,7 @@ def quad_weights_mw_theta_only(L: int) -> torch.tensor:
221221
2 * L - 1
222222
)
223223
q = wr[:L]
224-
q[: L - 1] += wr[-1 : L - 1 : -1]
224+
q[: L - 1] += torch.flip(wr, dims=[0])[: L - 1]
225225

226226
return q
227227

tests/test_quadrature.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
1+
from jax import config
2+
3+
config.update("jax_enable_x64", True)
14
import pytest
25
import numpy as np
36
from s2fft.sampling import s2_samples as samples
4-
from s2fft.utils import quadrature
7+
from s2fft.utils import quadrature, quadrature_jax, quadrature_torch
58
from s2fft.base_transforms import spherical
69

710

811
@pytest.mark.parametrize("L", [5, 6])
912
@pytest.mark.parametrize("sampling", ["mw", "mwss"])
10-
def test_quadrature_mw_weights(flm_generator, L: int, sampling: str):
13+
@pytest.mark.parametrize("method", ["numpy", "jax", "torch"])
14+
def test_quadrature_mw_weights(flm_generator, L: int, sampling: str, method: str):
1115
spin = 0
1216

13-
q = quadrature.quad_weights(L, sampling, spin)
17+
if method.lower() == "numpy":
18+
q = quadrature.quad_weights(L, sampling, spin)
19+
elif method.lower() == "jax":
20+
q = quadrature_jax.quad_weights(L, sampling)
21+
elif method.lower() == "torch":
22+
q = quadrature_torch.quad_weights(L, sampling).numpy()
1423

1524
flm = flm_generator(L, spin, reality=False)
1625

0 commit comments

Comments
 (0)