Skip to content

Commit 268e621

Browse files
committed
Add benchmarks for precompute versions of spherical transforms
1 parent fdf93e3 commit 268e621

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

benchmarks/precompute_spherical.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Benchmarks for spherical transforms."""
2+
3+
import numpy as np
4+
import pyssht
5+
from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip
6+
7+
import s2fft
8+
import s2fft.precompute_transforms
9+
from s2fft.sampling import s2_samples as samples
10+
11+
L_VALUES = [8, 16, 32, 64, 128, 256]
12+
SPIN_VALUES = [0]
13+
SAMPLING_VALUES = ["mw"]
14+
METHOD_VALUES = ["numpy", "jax"]
15+
REALITY_VALUES = [True]
16+
17+
18+
def setup_forward(method, L, sampling, spin, reality):
19+
if reality and spin != 0:
20+
skip("Reality only valid for scalar fields (spin=0).")
21+
rng = np.random.default_rng()
22+
flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
23+
f = pyssht.inverse(
24+
samples.flm_2d_to_1d(flm, L),
25+
L,
26+
Method=sampling.upper(),
27+
Spin=spin,
28+
Reality=reality,
29+
)
30+
kernel_function = (
31+
s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
32+
if method == "jax"
33+
else s2fft.precompute_transforms.construct.spin_spherical_kernel
34+
)
35+
kernel = kernel_function(
36+
L=L, spin=spin, reality=reality, sampling=sampling, forward=True
37+
)
38+
return {"f": f, "kernel": kernel}
39+
40+
41+
@benchmark(
42+
setup_forward,
43+
method=METHOD_VALUES,
44+
L=L_VALUES,
45+
sampling=SAMPLING_VALUES,
46+
spin=SPIN_VALUES,
47+
reality=REALITY_VALUES,
48+
)
49+
def forward(f, kernel, method, L, sampling, spin, reality):
50+
flm = s2fft.precompute_transforms.spherical.forward(
51+
f=f,
52+
L=L,
53+
spin=spin,
54+
kernel=kernel,
55+
sampling=sampling,
56+
reality=reality,
57+
method=method,
58+
)
59+
if method == "jax":
60+
flm.block_until_ready()
61+
62+
63+
def setup_inverse(method, L, sampling, spin, reality):
64+
if reality and spin != 0:
65+
skip("Reality only valid for scalar fields (spin=0).")
66+
rng = np.random.default_rng()
67+
flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
68+
kernel_function = (
69+
s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
70+
if method == "jax"
71+
else s2fft.precompute_transforms.construct.spin_spherical_kernel
72+
)
73+
kernel = kernel_function(
74+
L=L, spin=spin, reality=reality, sampling=sampling, forward=False
75+
)
76+
return {"flm": flm, "kernel": kernel}
77+
78+
79+
@benchmark(
80+
setup_inverse,
81+
method=METHOD_VALUES,
82+
L=L_VALUES,
83+
sampling=SAMPLING_VALUES,
84+
spin=SPIN_VALUES,
85+
reality=REALITY_VALUES,
86+
)
87+
def inverse(flm, kernel, method, L, sampling, spin, reality):
88+
f = s2fft.precompute_transforms.spherical.inverse(
89+
flm=flm,
90+
L=L,
91+
spin=spin,
92+
kernel=kernel,
93+
sampling=sampling,
94+
reality=reality,
95+
method=method,
96+
)
97+
if method == "jax":
98+
f.block_until_ready()
99+
100+
101+
if __name__ == "__main__":
102+
results = parse_args_collect_and_run_benchmarks()

0 commit comments

Comments
 (0)