|
| 1 | +from jax import config |
| 2 | + |
| 3 | +config.update("jax_enable_x64", True) |
1 | 4 | import pytest |
2 | 5 | import numpy as np |
3 | 6 | from s2fft.sampling import s2_samples as samples |
4 | | -from s2fft.utils import quadrature |
| 7 | +from s2fft.utils import quadrature, quadrature_jax, quadrature_torch |
5 | 8 | from s2fft.base_transforms import spherical |
6 | 9 |
|
7 | 10 |
|
8 | 11 | @pytest.mark.parametrize("L", [5, 6]) |
9 | 12 | @pytest.mark.parametrize("sampling", ["mw", "mwss"]) |
10 | | -def test_quadrature_mw_weights(flm_generator, L: int, sampling: str): |
| 13 | +@pytest.mark.parametrize("method", ["numpy", "jax", "torch"]) |
| 14 | +def test_quadrature_mw_weights(flm_generator, L: int, sampling: str, method: str): |
11 | 15 | spin = 0 |
12 | 16 |
|
13 | | - q = quadrature.quad_weights(L, sampling, spin) |
| 17 | + if method.lower() == "numpy": |
| 18 | + q = quadrature.quad_weights(L, sampling, spin) |
| 19 | + elif method.lower() == "jax": |
| 20 | + q = quadrature_jax.quad_weights(L, sampling) |
| 21 | + elif method.lower() == "torch": |
| 22 | + q = quadrature_torch.quad_weights(L, sampling).numpy() |
14 | 23 |
|
15 | 24 | flm = flm_generator(L, spin, reality=False) |
16 | 25 |
|
|
0 commit comments