88from s2fft .transforms import wigner
99from s2fft .recursions .price_mcewen import generate_precomputes_wigner_jax
1010
11- L_to_test = [16 ]
12- N_to_test = [2 ]
13- L_lower_to_test = [0 , 2 ]
11+ L_to_test = [6 ]
12+ N_to_test = [3 ]
13+ L_lower_to_test = [1 ]
1414sampling_to_test = ["mw" , "mwss" , "dh" ]
15- reality_to_test = [False ]
16- multiple_gpus = [False ]
15+ reality_to_test = [False , True ]
1716
1817
1918@pytest .mark .parametrize ("L" , L_to_test )
2019@pytest .mark .parametrize ("N" , N_to_test )
2120@pytest .mark .parametrize ("L_lower" , L_lower_to_test )
2221@pytest .mark .parametrize ("sampling" , sampling_to_test )
2322@pytest .mark .parametrize ("reality" , reality_to_test )
24- @pytest .mark .parametrize ("spmd" , multiple_gpus )
2523@pytest .mark .filterwarnings ("ignore::RuntimeWarning" )
2624def test_inverse_wigner_custom_gradients (
2725 flmn_generator ,
@@ -30,7 +28,6 @@ def test_inverse_wigner_custom_gradients(
3028 L_lower : int ,
3129 sampling : str ,
3230 reality : bool ,
33- spmd : bool ,
3431):
3532 precomps = generate_precomputes_wigner_jax (
3633 L , N , sampling , None , False , reality , L_lower
@@ -39,12 +36,12 @@ def test_inverse_wigner_custom_gradients(
3936 flmn = flmn_generator (L = L , N = N , L_lower = L_lower , reality = reality )
4037 flmn_target = flmn_generator (L = L , N = N , L_lower = L_lower , reality = reality )
4138 f_target = wigner .inverse_jax (
42- flmn_target , L , N , None , sampling , reality , precomps , spmd , L_lower
39+ flmn_target , L , N , None , sampling , reality , precomps , False , L_lower
4340 )
4441
4542 def func (flmn ):
4643 f = wigner .inverse_jax (
47- flmn , L , N , None , sampling , reality , precomps , spmd , L_lower
44+ flmn , L , N , None , sampling , reality , precomps , False , L_lower
4845 )
4946 return jnp .sum (jnp .abs (f - f_target ) ** 2 )
5047
@@ -56,7 +53,6 @@ def func(flmn):
5653@pytest .mark .parametrize ("L_lower" , L_lower_to_test )
5754@pytest .mark .parametrize ("sampling" , sampling_to_test )
5855@pytest .mark .parametrize ("reality" , reality_to_test )
59- @pytest .mark .parametrize ("spmd" , multiple_gpus )
6056@pytest .mark .filterwarnings ("ignore::RuntimeWarning" )
6157def test_forward_wigner_custom_gradients (
6258 flmn_generator ,
@@ -65,7 +61,6 @@ def test_forward_wigner_custom_gradients(
6561 L_lower : int ,
6662 sampling : str ,
6763 reality : bool ,
68- spmd : bool ,
6964):
7065 precomps = generate_precomputes_wigner_jax (
7166 L , N , sampling , None , True , reality , L_lower
@@ -74,12 +69,12 @@ def test_forward_wigner_custom_gradients(
7469 flmn_target = flmn_generator (L = L , N = N , L_lower = L_lower , reality = reality )
7570 flmn = flmn_generator (L = L , N = N , L_lower = L_lower , reality = reality )
7671 f = wigner .inverse_jax (
77- flmn , L , N , None , sampling , reality , None , spmd , L_lower
72+ flmn , L , N , None , sampling , reality , None , False , L_lower
7873 )
7974
8075 def func (f ):
8176 flmn = wigner .forward_jax (
82- f , L , N , None , sampling , reality , precomps , spmd , L_lower
77+ f , L , N , None , sampling , reality , precomps , False , L_lower
8378 )
8479 return jnp .sum (jnp .abs (flmn - flmn_target ) ** 2 )
8580
0 commit comments