@@ -368,6 +368,7 @@ def f(c):
368368 # should not be zero.
369369 self .assertNotEqual (jax .grad (f )(0. ), 0 )
370370
371+ @chex .all_variants (with_pmap = False , with_device = False )
371372 @parameterized .named_parameters (
372373 ('small concentration' , 1. ),
373374 ('medium concentration' , 10. ),
@@ -382,29 +383,26 @@ def f(seed, l, c):
382383 vm = self .distrax_cls (l , c )
383384 x = vm .sample (seed = seed ) # pylint: disable=cell-var-from-loop
384385 return x
385- jax_sample_and_grad = jax .value_and_grad (f , argnums = 2 )
386+ jax_sample_and_grad = self . variant ( jax .value_and_grad (f , argnums = 2 ) )
386387
387388 def samples_grad (s , concentration ):
388389 broadcast_concentration = concentration
389390 _ , dcdf_dconcentration = tfp .math .value_and_gradient (
390391 lambda conc : tfp .distributions .von_mises .von_mises_cdf (s , conc ),
391392 broadcast_concentration )
392393 inv_prob = np .exp (- concentration * (np .cos (s ) - 1. )) * (
393- (2. * np .pi ) * scipy .special .i0e (concentration )
394- )
394+ (2. * np .pi ) * scipy .special .i0e (concentration ))
395395 # Computes the implicit derivative,
396396 # dz = dconc * -(dF(z; conc) / dconc) / p(z; conc)
397397 dsamples = - dcdf_dconcentration * inv_prob
398398 return dsamples
399399
400400 for seed in range (10 ):
401401 sample , sample_grad = jax_sample_and_grad (
402- seed , jnp .array (locs ), jnp .array (concentration )
403- )
402+ seed , jnp .array (locs ), jnp .array (concentration ))
404403 comparison = samples_grad (sample , concentration )
405404 np .testing .assert_allclose (
406- comparison , sample_grad , rtol = 1e-06 , atol = 1e-06
407- )
405+ comparison , sample_grad , rtol = 1e-06 , atol = 1e-06 )
408406
409407 def test_von_mises_sample_moments (self ):
410408 locs_v = np .array ([- 1. , 0.3 , 2.3 ])
0 commit comments