1616SAMPLING_VALUES = ["mw" ]
1717METHOD_VALUES = ["numpy" , "jax" ]
1818REALITY_VALUES = [True ]
19- SPMD_VALUES = [False ]
2019
2120
22- def setup_forward (method , L , L_lower , N , sampling , reality , spmd ):
23- if spmd and method != "jax" :
24- skip ("GPU distribution only valid for JAX." )
21+ def setup_forward (method , L , L_lower , N , sampling , reality ):
2522 rng = np .random .default_rng ()
2623 flmn = s2fft .utils .signal_generator .generate_flmn (rng , L , N , reality = reality )
2724 f = base_wigner .inverse (
@@ -51,27 +48,23 @@ def setup_forward(method, L, L_lower, N, sampling, reality, spmd):
5148 N = N_VALUES ,
5249 sampling = SAMPLING_VALUES ,
5350 reality = REALITY_VALUES ,
54- spmd = SPMD_VALUES ,
5551)
56- def forward (f , precomps , method , L , L_lower , N , sampling , reality , spmd ):
52+ def forward (f , precomps , method , L , L_lower , N , sampling , reality ):
5753 flmn = s2fft .transforms .wigner .forward (
5854 f = f ,
5955 L = L ,
60- L_lower = L_lower ,
6156 N = N ,
62- precomps = precomps ,
6357 sampling = sampling ,
64- reality = reality ,
6558 method = method ,
66- spmd = spmd ,
59+ reality = reality ,
60+ precomps = precomps ,
61+ L_lower = L_lower ,
6762 )
6863 if method == "jax" :
6964 flmn .block_until_ready ()
7065
7166
72- def setup_inverse (method , L , L_lower , N , sampling , reality , spmd ):
73- if spmd and method != "jax" :
74- skip ("GPU distribution only valid for JAX." )
67+ def setup_inverse (method , L , L_lower , N , sampling , reality ):
7568 rng = np .random .default_rng ()
7669 flmn = s2fft .utils .signal_generator .generate_flmn (rng , L , N , reality = reality )
7770 generate_precomputes = (
@@ -93,19 +86,17 @@ def setup_inverse(method, L, L_lower, N, sampling, reality, spmd):
9386 N = N_VALUES ,
9487 sampling = SAMPLING_VALUES ,
9588 reality = REALITY_VALUES ,
96- spmd = SPMD_VALUES ,
9789)
98- def inverse (flmn , precomps , method , L , L_lower , N , sampling , reality , spmd ):
99- f = s2fft .transforms .spherical .inverse (
100- flm = flmn ,
90+ def inverse (flmn , precomps , method , L , L_lower , N , sampling , reality ):
91+ f = s2fft .transforms .wigner .inverse (
92+ flmn = flmn ,
10193 L = L ,
102- L_lower = L_lower ,
10394 N = N ,
104- precomps = precomps ,
10595 sampling = sampling ,
10696 reality = reality ,
10797 method = method ,
108- spmd = spmd ,
98+ precomps = precomps ,
99+ L_lower = L_lower ,
109100 )
110101 if method == "jax" :
111102 f .block_until_ready ()
0 commit comments