Skip to content

Commit 8c4227d

Browse files
committed
reduce total number of test cases, reactivate CI
1 parent 01e811f commit 8c4227d

12 files changed

+43
-63
lines changed

.github/workflows/tests.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ jobs:
2525
with:
2626
python-version: ${{ matrix.python-version }}
2727

28-
# - name: Install dependencies
29-
# run: |
30-
# python -m pip install --upgrade pip
31-
# pip install -r requirements/requirements-tests.txt
32-
# pip install -r requirements/requirements-core.txt
33-
# pip install .
28+
- name: Install dependencies
29+
run: |
30+
python -m pip install --upgrade pip
31+
pip install -r requirements/requirements-tests.txt
32+
pip install -r requirements/requirements-core.txt
33+
pip install .
3434
35-
# - name: Run tests
36-
# run: |
37-
# pytest --cov-report term --cov=s2wav --cov-config=.coveragerc
38-
# codecov --token 298dc7ee-bb9f-4221-b31f-3576cc6cb702
35+
- name: Run tests
36+
run: |
37+
pytest --cov-report term --cov=s2wav --cov-config=.coveragerc
38+
codecov --token 298dc7ee-bb9f-4221-b31f-3576cc6cb702

tests/test_samples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from s2fft.sampling import s2_samples as samples
66

77

8-
nside_to_test = [32, 64, 128]
8+
nside_to_test = [16, 32]
99

1010

1111
@pytest.mark.parametrize("L", [15, 16])

tests/test_spherical_base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
import healpy as hp
77

88

9-
L_to_test = [6, 7, 8]
10-
L_lower_to_test = [0, 1, 2]
11-
spin_to_test = [-2, -1, 0, 1, 2]
12-
nside_to_test = [2, 4, 8]
9+
L_to_test = [6, 7]
10+
L_lower_to_test = [0, 2]
11+
spin_to_test = [-2, 0, 1]
12+
nside_to_test = [4, 5]
1313
L_to_nside_ratio = [2, 3]
1414
sampling_to_test = ["mw", "mwss", "dh"]
1515
method_to_test = ["direct", "sov", "sov_fft", "sov_fft_vectorized"]

tests/test_spherical_custom_grads.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,25 @@
22

33
config.update("jax_enable_x64", True)
44
import pytest
5-
import numpy as np
65
import jax.numpy as jnp
76
from jax.test_util import check_grads
87

9-
from s2fft.sampling import s2_samples as samples
108
from s2fft.transforms import spherical
119
from s2fft.recursions.price_mcewen import generate_precomputes_jax
1210

1311
L_to_test = [16]
14-
L_lower_to_test = [0, 2]
12+
L_lower_to_test = [2]
1513
spin_to_test = [-2, 0, 1]
16-
nside_to_test = [8, 10]
14+
nside_to_test = [8]
1715
sampling_to_test = ["mw", "mwss", "dh"]
1816
reality_to_test = [False, True]
19-
multiple_gpus = [False]
2017

2118

2219
@pytest.mark.parametrize("L", L_to_test)
2320
@pytest.mark.parametrize("L_lower", L_lower_to_test)
2421
@pytest.mark.parametrize("spin", spin_to_test)
2522
@pytest.mark.parametrize("sampling", sampling_to_test)
2623
@pytest.mark.parametrize("reality", reality_to_test)
27-
@pytest.mark.parametrize("spmd", multiple_gpus)
2824
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
2925
def test_inverse_custom_gradients(
3026
flm_generator,
@@ -33,7 +29,6 @@ def test_inverse_custom_gradients(
3329
spin: int,
3430
sampling: str,
3531
reality: bool,
36-
spmd: bool,
3732
):
3833
if reality and spin != 0:
3934
pytest.skip("Reality only valid for scalar fields (spin=0).")
@@ -61,7 +56,6 @@ def func(flm):
6156
reality=reality,
6257
precomps=precomps,
6358
sampling=sampling,
64-
spmd=spmd,
6559
)
6660
return jnp.sum(jnp.abs(f - f_target) ** 2)
6761

@@ -73,7 +67,6 @@ def func(flm):
7367
@pytest.mark.parametrize("spin", spin_to_test)
7468
@pytest.mark.parametrize("sampling", sampling_to_test)
7569
@pytest.mark.parametrize("reality", reality_to_test)
76-
@pytest.mark.parametrize("spmd", multiple_gpus)
7770
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
7871
def test_forward_custom_gradients(
7972
flm_generator,
@@ -82,7 +75,6 @@ def test_forward_custom_gradients(
8275
spin: int,
8376
sampling: str,
8477
reality: bool,
85-
spmd: bool,
8678
):
8779
if reality and spin != 0:
8880
pytest.skip("Reality only valid for scalar fields (spin=0).")
@@ -110,7 +102,6 @@ def func(f):
110102
reality=reality,
111103
precomps=precomps,
112104
sampling=sampling,
113-
spmd=spmd,
114105
)
115106
return jnp.sum(jnp.abs(flm - flm_target) ** 2)
116107

@@ -122,7 +113,6 @@ def func(f):
122113
@pytest.mark.parametrize("spin", spin_to_test)
123114
@pytest.mark.parametrize("sampling", sampling_to_test)
124115
@pytest.mark.parametrize("reality", reality_to_test)
125-
@pytest.mark.parametrize("spmd", multiple_gpus)
126116
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
127117
def test_healpix_inverse_custom_gradients(
128118
flm_generator,
@@ -131,7 +121,6 @@ def test_healpix_inverse_custom_gradients(
131121
spin: int,
132122
sampling: str,
133123
reality: bool,
134-
spmd: bool,
135124
):
136125
sampling = "healpix"
137126
L = 2 * nside
@@ -164,7 +153,6 @@ def func(flm):
164153
reality=reality,
165154
precomps=precomps,
166155
sampling=sampling,
167-
spmd=spmd,
168156
)
169157
return jnp.sum(jnp.abs(f - f_target) ** 2)
170158

@@ -176,7 +164,6 @@ def func(flm):
176164
@pytest.mark.parametrize("spin", spin_to_test)
177165
@pytest.mark.parametrize("sampling", sampling_to_test)
178166
@pytest.mark.parametrize("reality", reality_to_test)
179-
@pytest.mark.parametrize("spmd", multiple_gpus)
180167
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
181168
def test_healpix_forward_custom_gradients(
182169
flm_generator,
@@ -185,7 +172,6 @@ def test_healpix_forward_custom_gradients(
185172
spin: int,
186173
sampling: str,
187174
reality: bool,
188-
spmd: bool,
189175
):
190176
sampling = "healpix"
191177
L = 2 * nside
@@ -218,7 +204,6 @@ def func(f):
218204
reality=reality,
219205
precomps=precomps,
220206
sampling=sampling,
221-
spmd=spmd,
222207
)
223208
return jnp.sum(jnp.abs(flm - flm_target) ** 2)
224209

tests/test_spherical_precompute.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from s2fft.precompute_transforms.construct import spin_spherical_kernel
55
from s2fft.base_transforms import spherical as base
66

7-
L_to_test = [6, 7, 8]
8-
spin_to_test = [-2, -1, 0, 1, 2]
9-
nside_to_test = [2, 4, 8]
7+
L_to_test = [6, 7]
8+
spin_to_test = [-2, 0, 1]
9+
nside_to_test = [4, 5]
1010
L_to_nside_ratio = [2, 3]
1111
sampling_to_test = ["mw", "mwss", "dh"]
1212
reality_to_test = [True, False]

tests/test_spherical_transform.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from s2fft.transforms import spherical
1111
from s2fft.recursions.price_mcewen import generate_precomputes
1212

13-
L_to_test = [6, 7, 8]
14-
L_lower_to_test = [0, 1, 2]
15-
spin_to_test = [-2, -1, 0, 1, 2]
16-
nside_to_test = [2, 4, 8]
13+
L_to_test = [6, 7]
14+
L_lower_to_test = [0, 2]
15+
spin_to_test = [-2, 0, 1]
16+
nside_to_test = [4, 5]
1717
sampling_to_test = ["mw", "mwss", "dh"]
1818
method_to_test = ["numpy", "jax"]
1919
reality_to_test = [False, True]

tests/test_wigner_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from s2fft.base_transforms import wigner
66

77

8-
L_to_test = [8, 16]
9-
N_to_test = [2, 4, 6]
10-
L_lower_to_test = [0, 2, 4]
8+
L_to_test = [6, 7]
9+
N_to_test = [2, 3]
10+
L_lower_to_test = [0, 2]
1111
sampling_schemes_so3 = ["mw", "mwss"]
1212
sampling_schemes = ["mw", "mwss", "dh"]
1313
reality_to_test = [False, True]

tests/test_wigner_custom_grads.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,18 @@
88
from s2fft.transforms import wigner
99
from s2fft.recursions.price_mcewen import generate_precomputes_wigner_jax
1010

11-
L_to_test = [16]
12-
N_to_test = [2]
13-
L_lower_to_test = [0, 2]
11+
L_to_test = [6]
12+
N_to_test = [3]
13+
L_lower_to_test = [1]
1414
sampling_to_test = ["mw", "mwss", "dh"]
15-
reality_to_test = [False]
16-
multiple_gpus = [False]
15+
reality_to_test = [False, True]
1716

1817

1918
@pytest.mark.parametrize("L", L_to_test)
2019
@pytest.mark.parametrize("N", N_to_test)
2120
@pytest.mark.parametrize("L_lower", L_lower_to_test)
2221
@pytest.mark.parametrize("sampling", sampling_to_test)
2322
@pytest.mark.parametrize("reality", reality_to_test)
24-
@pytest.mark.parametrize("spmd", multiple_gpus)
2523
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
2624
def test_inverse_wigner_custom_gradients(
2725
flmn_generator,
@@ -30,7 +28,6 @@ def test_inverse_wigner_custom_gradients(
3028
L_lower: int,
3129
sampling: str,
3230
reality: bool,
33-
spmd: bool,
3431
):
3532
precomps = generate_precomputes_wigner_jax(
3633
L, N, sampling, None, False, reality, L_lower
@@ -39,12 +36,12 @@ def test_inverse_wigner_custom_gradients(
3936
flmn = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality)
4037
flmn_target = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality)
4138
f_target = wigner.inverse_jax(
42-
flmn_target, L, N, None, sampling, reality, precomps, spmd, L_lower
39+
flmn_target, L, N, None, sampling, reality, precomps, False, L_lower
4340
)
4441

4542
def func(flmn):
4643
f = wigner.inverse_jax(
47-
flmn, L, N, None, sampling, reality, precomps, spmd, L_lower
44+
flmn, L, N, None, sampling, reality, precomps, False, L_lower
4845
)
4946
return jnp.sum(jnp.abs(f - f_target) ** 2)
5047

@@ -56,7 +53,6 @@ def func(flmn):
5653
@pytest.mark.parametrize("L_lower", L_lower_to_test)
5754
@pytest.mark.parametrize("sampling", sampling_to_test)
5855
@pytest.mark.parametrize("reality", reality_to_test)
59-
@pytest.mark.parametrize("spmd", multiple_gpus)
6056
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
6157
def test_forward_wigner_custom_gradients(
6258
flmn_generator,
@@ -65,7 +61,6 @@ def test_forward_wigner_custom_gradients(
6561
L_lower: int,
6662
sampling: str,
6763
reality: bool,
68-
spmd: bool,
6964
):
7065
precomps = generate_precomputes_wigner_jax(
7166
L, N, sampling, None, True, reality, L_lower
@@ -74,12 +69,12 @@ def test_forward_wigner_custom_gradients(
7469
flmn_target = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality)
7570
flmn = flmn_generator(L=L, N=N, L_lower=L_lower, reality=reality)
7671
f = wigner.inverse_jax(
77-
flmn, L, N, None, sampling, reality, None, spmd, L_lower
72+
flmn, L, N, None, sampling, reality, None, False, L_lower
7873
)
7974

8075
def func(f):
8176
flmn = wigner.forward_jax(
82-
f, L, N, None, sampling, reality, precomps, spmd, L_lower
77+
f, L, N, None, sampling, reality, precomps, False, L_lower
8378
)
8479
return jnp.sum(jnp.abs(flmn - flmn_target) ** 2)
8580

tests/test_wigner_precompute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from s2fft.base_transforms import wigner as base
77

88
L_to_test = [8, 10]
9-
N_to_test = [2, 4]
9+
N_to_test = [2, 3]
1010
nside_to_test = [4, 6]
11-
L_to_nside_ratio = [2, 3]
11+
L_to_nside_ratio = [2]
1212
reality_to_test = [False, True]
1313
sampling_schemes = ["mw", "mwss", "dh"]
1414
methods_to_test = ["numpy", "jax"]

tests/test_wigner_recursions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from s2fft.sampling import s2_samples as samples
99
import pyssht as ssht
1010

11-
L_to_test = [8, 16]
12-
spin_to_test = np.arange(-2, 2)
11+
L_to_test = [6, 7]
12+
spin_to_test = [-2, 0, 1]
1313
sampling_schemes = ["mw", "mwss", "dh", "healpix"]
1414

1515

0 commit comments

Comments
 (0)