55from benchmarking import benchmark , parse_args_collect_and_run_benchmarks , skip
66
77import s2fft
8- from s2fft .recursions .price_mcewen import generate_precomputes
8+ from s2fft .recursions .price_mcewen import generate_precomputes_jax
99from s2fft .sampling import s2_samples as samples
1010
1111L_VALUES = [8 , 16 , 32 , 64 , 128 , 256 ]
1717SPMD_VALUES = [False ]
1818
1919
20+ def _jax_arrays_to_numpy (precomps ):
21+ return [np .asarray (p ) for p in precomps ]
22+
23+
2024def setup_forward (method , L , L_lower , sampling , spin , reality , spmd ):
2125 if reality and spin != 0 :
2226 skip ("Reality only valid for scalar fields (spin=0)." )
@@ -31,7 +35,11 @@ def setup_forward(method, L, L_lower, sampling, spin, reality, spmd):
3135 Spin = spin ,
3236 Reality = reality ,
3337 )
34- precomps = generate_precomputes (L , spin , sampling , forward = True , L_lower = L_lower )
38+ precomps = generate_precomputes_jax (
39+ L , spin , sampling , forward = True , L_lower = L_lower
40+ )
41+ if method == "numpy" :
42+ precomps = _jax_arrays_to_numpy (precomps )
3543 return {"f" : f , "precomps" : precomps }
3644
3745
@@ -71,7 +79,11 @@ def setup_inverse(method, L, L_lower, sampling, spin, reality, spmd):
7179 skip ("GPU distribution only valid for JAX." )
7280 rng = np .random .default_rng ()
7381 flm = s2fft .utils .signal_generator .generate_flm (rng , L , spin = spin , reality = reality )
74- precomps = generate_precomputes (L , spin , sampling , forward = False , L_lower = L_lower )
82+ precomps = generate_precomputes_jax (
83+ L , spin , sampling , forward = False , L_lower = L_lower
84+ )
85+ if method == "numpy" :
86+ precomps = _jax_arrays_to_numpy (precomps )
7587 return {"flm" : flm , "precomps" : precomps }
7688
7789
0 commit comments