Skip to content

Commit b3e3283

Browse files
committed
Add precompute Wigner benchmarks
1 parent 0c87cd6 commit b3e3283

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

benchmarks/precompute_wigner.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Benchmarks for precompute Wigner-d transforms."""
2+
3+
import numpy as np
4+
from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip
5+
6+
import s2fft
7+
import s2fft.precompute_transforms
8+
from s2fft.base_transforms import wigner as base_wigner
9+
from s2fft.sampling import s2_samples as samples
10+
11+
L_VALUES = [16, 32, 64, 128, 256]
12+
N_VALUES = [2]
13+
L_LOWER_VALUES = [0]
14+
SAMPLING_VALUES = ["mw"]
15+
METHOD_VALUES = ["numpy", "jax"]
16+
REALITY_VALUES = [True]
17+
18+
def setup_forward(method, L, N, L_lower, sampling, reality):
19+
rng = np.random.default_rng()
20+
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
21+
f = base_wigner.inverse(
22+
flmn,
23+
L,
24+
N,
25+
L_lower=L_lower,
26+
sampling=sampling,
27+
reality=reality,
28+
)
29+
kernel_function = (
30+
s2fft.precompute_transforms.construct.wigner_kernel_jax
31+
if method == "jax"
32+
else s2fft.precompute_transforms.construct.wigner_kernel
33+
)
34+
kernel = kernel_function(
35+
L=L, N=N, reality=reality, sampling=sampling, forward=True
36+
)
37+
return {"f": f, "kernel": kernel}
38+
39+
40+
@benchmark(
41+
setup_forward,
42+
method=METHOD_VALUES,
43+
L=L_VALUES,
44+
N=N_VALUES,
45+
L_lower=L_LOWER_VALUES,
46+
sampling=SAMPLING_VALUES,
47+
reality=REALITY_VALUES,
48+
)
49+
def forward(f, kernel, method, L, N, L_lower, sampling, reality):
50+
flmn = s2fft.precompute_transforms.wigner.forward(
51+
f=f,
52+
L=L,
53+
N=N,
54+
kernel=kernel,
55+
sampling=sampling,
56+
reality=reality,
57+
method=method,
58+
)
59+
if method == "jax":
60+
flmn.block_until_ready()
61+
62+
63+
def setup_inverse(method, L, N, L_lower, sampling, reality):
64+
rng = np.random.default_rng()
65+
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
66+
kernel_function = (
67+
s2fft.precompute_transforms.construct.wigner_kernel_jax
68+
if method == "jax"
69+
else s2fft.precompute_transforms.construct.wigner_kernel
70+
)
71+
kernel = kernel_function(
72+
L=L, N=N, reality=reality, sampling=sampling, forward=False
73+
)
74+
return {"flmn": flmn, "kernel": kernel}
75+
76+
77+
@benchmark(
78+
setup_inverse,
79+
method=METHOD_VALUES,
80+
L=L_VALUES,
81+
N=N_VALUES,
82+
L_lower=L_LOWER_VALUES,
83+
sampling=SAMPLING_VALUES,
84+
reality=REALITY_VALUES,
85+
)
86+
def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality):
87+
f = s2fft.precompute_transforms.wigner.inverse(
88+
flmn=flmn,
89+
L=L,
90+
N=N,
91+
kernel=kernel,
92+
sampling=sampling,
93+
reality=reality,
94+
method=method,
95+
)
96+
if method == "jax":
97+
f.block_until_ready()
98+
99+
100+
if __name__ == "__main__":
101+
results = parse_args_collect_and_run_benchmarks()

0 commit comments

Comments
 (0)