Skip to content

Commit fdf93e3

Browse files
committed
Use JAX function for generating precomputes
1 parent 7cf80bb commit fdf93e3

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

benchmarks/spherical.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from benchmarking import benchmark, parse_args_collect_and_run_benchmarks, skip
66

77
import s2fft
8-
from s2fft.recursions.price_mcewen import generate_precomputes
8+
from s2fft.recursions.price_mcewen import generate_precomputes_jax
99
from s2fft.sampling import s2_samples as samples
1010

1111
L_VALUES = [8, 16, 32, 64, 128, 256]
@@ -17,6 +17,10 @@
1717
SPMD_VALUES = [False]
1818

1919

20+
def _jax_arrays_to_numpy(precomps):
21+
return [np.asarray(p) for p in precomps]
22+
23+
2024
def setup_forward(method, L, L_lower, sampling, spin, reality, spmd):
2125
if reality and spin != 0:
2226
skip("Reality only valid for scalar fields (spin=0).")
@@ -31,7 +35,11 @@ def setup_forward(method, L, L_lower, sampling, spin, reality, spmd):
3135
Spin=spin,
3236
Reality=reality,
3337
)
34-
precomps = generate_precomputes(L, spin, sampling, forward=True, L_lower=L_lower)
38+
precomps = generate_precomputes_jax(
39+
L, spin, sampling, forward=True, L_lower=L_lower
40+
)
41+
if method == "numpy":
42+
precomps = _jax_arrays_to_numpy(precomps)
3543
return {"f": f, "precomps": precomps}
3644

3745

@@ -71,7 +79,11 @@ def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd):
7179
skip("GPU distribution only valid for JAX.")
7280
rng = np.random.default_rng()
7381
flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
74-
precomps = generate_precomputes(L, spin, sampling, forward=False, L_lower=L_lower)
82+
precomps = generate_precomputes_jax(
83+
L, spin, sampling, forward=False, L_lower=L_lower
84+
)
85+
if method == "numpy":
86+
precomps = _jax_arrays_to_numpy(precomps)
7587
return {"flm": flm, "precomps": precomps}
7688

7789

0 commit comments

Comments
 (0)