Skip to content

Commit f4e9458

Browse files
committed
Expose recursion and mode parameters for precompute benchmarks
1 parent 54db6d2 commit f4e9458

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

benchmarks/precompute_spherical.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
SAMPLING_VALUES = ["mw"]
1414
METHOD_VALUES = ["numpy", "jax"]
1515
REALITY_VALUES = [True]
16+
RECURSION_VALUES = ["auto"]
1617

1718

18-
def setup_forward(method, L, sampling, spin, reality):
19+
def setup_forward(method, L, sampling, spin, reality, recursion):
1920
if reality and spin != 0:
2021
skip("Reality only valid for scalar fields (spin=0).")
2122
rng = np.random.default_rng()
@@ -33,7 +34,12 @@ def setup_forward(method, L, sampling, spin, reality):
3334
else s2fft.precompute_transforms.construct.spin_spherical_kernel
3435
)
3536
kernel = kernel_function(
36-
L=L, spin=spin, reality=reality, sampling=sampling, forward=True
37+
L=L,
38+
spin=spin,
39+
reality=reality,
40+
sampling=sampling,
41+
forward=True,
42+
recursion=recursion,
3743
)
3844
return {"f": f, "kernel": kernel}
3945

@@ -45,8 +51,9 @@ def setup_forward(method, L, sampling, spin, reality):
4551
sampling=SAMPLING_VALUES,
4652
spin=SPIN_VALUES,
4753
reality=REALITY_VALUES,
54+
recursion=RECURSION_VALUES,
4855
)
49-
def forward(f, kernel, method, L, sampling, spin, reality):
56+
def forward(f, kernel, method, L, sampling, spin, reality, recursion):
5057
flm = s2fft.precompute_transforms.spherical.forward(
5158
f=f,
5259
L=L,
@@ -60,7 +67,7 @@ def forward(f, kernel, method, L, sampling, spin, reality):
6067
flm.block_until_ready()
6168

6269

63-
def setup_inverse(method, L, sampling, spin, reality):
70+
def setup_inverse(method, L, sampling, spin, reality, recursion):
6471
if reality and spin != 0:
6572
skip("Reality only valid for scalar fields (spin=0).")
6673
rng = np.random.default_rng()
@@ -71,7 +78,12 @@ def setup_inverse(method, L, sampling, spin, reality):
7178
else s2fft.precompute_transforms.construct.spin_spherical_kernel
7279
)
7380
kernel = kernel_function(
74-
L=L, spin=spin, reality=reality, sampling=sampling, forward=False
81+
L=L,
82+
spin=spin,
83+
reality=reality,
84+
sampling=sampling,
85+
forward=False,
86+
recursion=recursion,
7587
)
7688
return {"flm": flm, "kernel": kernel}
7789

@@ -83,8 +95,9 @@ def setup_inverse(method, L, sampling, spin, reality):
8395
sampling=SAMPLING_VALUES,
8496
spin=SPIN_VALUES,
8597
reality=REALITY_VALUES,
98+
recursion=RECURSION_VALUES,
8699
)
87-
def inverse(flm, kernel, method, L, sampling, spin, reality):
100+
def inverse(flm, kernel, method, L, sampling, spin, reality, recursion):
88101
f = s2fft.precompute_transforms.spherical.inverse(
89102
flm=flm,
90103
L=L,

benchmarks/precompute_wigner.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
SAMPLING_VALUES = ["mw"]
1515
METHOD_VALUES = ["numpy", "jax"]
1616
REALITY_VALUES = [True]
17+
MODE_VALUES = ["auto"]
1718

18-
def setup_forward(method, L, N, L_lower, sampling, reality):
19+
20+
def setup_forward(method, L, N, L_lower, sampling, reality, mode):
1921
rng = np.random.default_rng()
2022
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
2123
f = base_wigner.inverse(
@@ -32,7 +34,7 @@ def setup_forward(method, L, N, L_lower, sampling, reality):
3234
else s2fft.precompute_transforms.construct.wigner_kernel
3335
)
3436
kernel = kernel_function(
35-
L=L, N=N, reality=reality, sampling=sampling, forward=True
37+
L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode
3638
)
3739
return {"f": f, "kernel": kernel}
3840

@@ -45,8 +47,9 @@ def setup_forward(method, L, N, L_lower, sampling, reality):
4547
L_lower=L_LOWER_VALUES,
4648
sampling=SAMPLING_VALUES,
4749
reality=REALITY_VALUES,
50+
mode=MODE_VALUES,
4851
)
49-
def forward(f, kernel, method, L, N, L_lower, sampling, reality):
52+
def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode):
5053
flmn = s2fft.precompute_transforms.wigner.forward(
5154
f=f,
5255
L=L,
@@ -60,7 +63,7 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality):
6063
flmn.block_until_ready()
6164

6265

63-
def setup_inverse(method, L, N, L_lower, sampling, reality):
66+
def setup_inverse(method, L, N, L_lower, sampling, reality, mode):
6467
rng = np.random.default_rng()
6568
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
6669
kernel_function = (
@@ -69,7 +72,7 @@ def setup_inverse(method, L, N, L_lower, sampling, reality):
6972
else s2fft.precompute_transforms.construct.wigner_kernel
7073
)
7174
kernel = kernel_function(
72-
L=L, N=N, reality=reality, sampling=sampling, forward=False
75+
L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode
7376
)
7477
return {"flmn": flmn, "kernel": kernel}
7578

@@ -82,8 +85,9 @@ def setup_inverse(method, L, N, L_lower, sampling, reality):
8285
L_lower=L_LOWER_VALUES,
8386
sampling=SAMPLING_VALUES,
8487
reality=REALITY_VALUES,
88+
mode=MODE_VALUES,
8589
)
86-
def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality):
90+
def inverse(flmn, kernel, method, L, N, L_lower, sampling, reality, mode):
8791
f = s2fft.precompute_transforms.wigner.inverse(
8892
flmn=flmn,
8993
L=L,

0 commit comments

Comments
 (0)