@@ -136,16 +136,6 @@ def test_trapani_interfaces():
136136 recursions .trapani .compute_full (dl_jax , L , el , implementation = "unexpected" )
137137
138138
139- def test_trapani_checks ():
140- # TODO
141-
142- # Check throws exception if arguments wrong
143-
144- # Check throws exception if don't init
145-
146- return
147-
148-
149139def test_risbo_with_ssht ():
150140 """Test Risbo computation against ssht"""
151141
@@ -171,15 +161,16 @@ def test_risbo_with_ssht_jax():
171161 L = 10
172162
173163 # Compute using SSHT.
174- beta = np .pi / 2.0
175- dl_array = ssht .generate_dl (beta , L )
164+ betas = [0 , np .pi / 2.0 , np .pi ]
165+ for beta in betas :
166+ dl_array = ssht .generate_dl (beta , L )
176167
177- # Compare to routines in SSHT, which have been validated extensively.
178- dl = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
168+ # Compare to routines in SSHT, which have been validated extensively.
169+ dl = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
179170
180- for el in range (0 , L ):
181- dl = recursions .risbo_jax .compute_full (dl , beta , L , el )
182- np .testing .assert_allclose (dl_array [el , :, :], dl , atol = 1e-15 )
171+ for el in range (0 , L ):
172+ dl = recursions .risbo_jax .compute_full (dl , beta , L , el )
173+ np .testing .assert_allclose (dl_array [el , :, :], dl , atol = 1e-15 )
183174
184175
185176@pytest .mark .parametrize ("L" , L_to_test )
0 commit comments