Skip to content

Commit aa37af4

Browse files
franrruizDistraxDev
authored andcommitted
Fix VonMises tests.
PiperOrigin-RevId: 485835879
1 parent 1a47ddc commit aa37af4

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

distrax/_src/distributions/von_mises_test.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def f(c):
368368
# should not be zero.
369369
self.assertNotEqual(jax.grad(f)(0.), 0)
370370

371+
@chex.all_variants(with_pmap=False, with_device=False)
371372
@parameterized.named_parameters(
372373
('small concentration', 1.),
373374
('medium concentration', 10.),
@@ -382,29 +383,26 @@ def f(seed, l, c):
382383
vm = self.distrax_cls(l, c)
383384
x = vm.sample(seed=seed) # pylint: disable=cell-var-from-loop
384385
return x
385-
jax_sample_and_grad = jax.value_and_grad(f, argnums=2)
386+
jax_sample_and_grad = self.variant(jax.value_and_grad(f, argnums=2))
386387

387388
def samples_grad(s, concentration):
388389
broadcast_concentration = concentration
389390
_, dcdf_dconcentration = tfp.math.value_and_gradient(
390391
lambda conc: tfp.distributions.von_mises.von_mises_cdf(s, conc),
391392
broadcast_concentration)
392393
inv_prob = np.exp(-concentration * (np.cos(s) - 1.)) * (
393-
(2. * np.pi) * scipy.special.i0e(concentration)
394-
)
394+
(2. * np.pi) * scipy.special.i0e(concentration))
395395
# Computes the implicit derivative,
396396
# dz = dconc * -(dF(z; conc) / dconc) / p(z; conc)
397397
dsamples = -dcdf_dconcentration * inv_prob
398398
return dsamples
399399

400400
for seed in range(10):
401401
sample, sample_grad = jax_sample_and_grad(
402-
seed, jnp.array(locs), jnp.array(concentration)
403-
)
402+
seed, jnp.array(locs), jnp.array(concentration))
404403
comparison = samples_grad(sample, concentration)
405404
np.testing.assert_allclose(
406-
comparison, sample_grad, rtol=1e-06, atol=1e-06
407-
)
405+
comparison, sample_grad, rtol=1e-06, atol=1e-06)
408406

409407
def test_von_mises_sample_moments(self):
410408
locs_v = np.array([-1., 0.3, 2.3])

0 commit comments

Comments
 (0)