Skip to content

Commit 942cf9e

Browse files
committed
Fix Wigner benchmarks
1 parent 742fdad commit 942cf9e

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

benchmarks/wigner.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,9 @@
1616
SAMPLING_VALUES = ["mw"]
1717
METHOD_VALUES = ["numpy", "jax"]
1818
REALITY_VALUES = [True]
19-
SPMD_VALUES = [False]
2019

2120

22-
def setup_forward(method, L, L_lower, N, sampling, reality, spmd):
23-
if spmd and method != "jax":
24-
skip("GPU distribution only valid for JAX.")
21+
def setup_forward(method, L, L_lower, N, sampling, reality):
2522
rng = np.random.default_rng()
2623
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
2724
f = base_wigner.inverse(
@@ -51,27 +48,23 @@ def setup_forward(method, L, L_lower, N, sampling, reality, spmd):
5148
N=N_VALUES,
5249
sampling=SAMPLING_VALUES,
5350
reality=REALITY_VALUES,
54-
spmd=SPMD_VALUES,
5551
)
56-
def forward(f, precomps, method, L, L_lower, N, sampling, reality, spmd):
52+
def forward(f, precomps, method, L, L_lower, N, sampling, reality):
5753
flmn = s2fft.transforms.wigner.forward(
5854
f=f,
5955
L=L,
60-
L_lower=L_lower,
6156
N=N,
62-
precomps=precomps,
6357
sampling=sampling,
64-
reality=reality,
6558
method=method,
66-
spmd=spmd,
59+
reality=reality,
60+
precomps=precomps,
61+
L_lower=L_lower,
6762
)
6863
if method == "jax":
6964
flmn.block_until_ready()
7065

7166

72-
def setup_inverse(method, L, L_lower, N, sampling, reality, spmd):
73-
if spmd and method != "jax":
74-
skip("GPU distribution only valid for JAX.")
67+
def setup_inverse(method, L, L_lower, N, sampling, reality):
7568
rng = np.random.default_rng()
7669
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
7770
generate_precomputes = (
@@ -93,19 +86,17 @@ def setup_inverse(method, L, L_lower, N, sampling, reality, spmd):
9386
N=N_VALUES,
9487
sampling=SAMPLING_VALUES,
9588
reality=REALITY_VALUES,
96-
spmd=SPMD_VALUES,
9789
)
98-
def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality, spmd):
99-
f = s2fft.transforms.spherical.inverse(
100-
flm=flmn,
90+
def inverse(flmn, precomps, method, L, L_lower, N, sampling, reality):
91+
f = s2fft.transforms.wigner.inverse(
92+
flmn=flmn,
10193
L=L,
102-
L_lower=L_lower,
10394
N=N,
104-
precomps=precomps,
10595
sampling=sampling,
10696
reality=reality,
10797
method=method,
108-
spmd=spmd,
98+
precomps=precomps,
99+
L_lower=L_lower,
109100
)
110101
if method == "jax":
111102
f.block_until_ready()

0 commit comments

Comments
 (0)