Skip to content

Commit 1c44b1e

Browse files
authored
Add tests for pathwise derivative (#1051)
1 parent ad35577 commit 1c44b1e

File tree

1 file changed

+24
-39
lines changed

1 file changed

+24
-39
lines changed

test/test_distributions.py

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import scipy.stats as osp
1313

1414
import jax
15-
from jax import grad, jacfwd, lax, vmap
15+
from jax import grad, lax, vmap
1616
import jax.numpy as jnp
1717
import jax.random as random
1818
from jax.scipy.special import expit, logsumexp
@@ -680,48 +680,33 @@ def fn(args):
680680

681681

682682
@pytest.mark.parametrize(
683-
"jax_dist, sp_dist, params",
683+
"jax_dist, params",
684684
[
685-
(dist.Gamma, osp.gamma, (1.0,)),
686-
(dist.Gamma, osp.gamma, (0.1,)),
687-
(dist.Gamma, osp.gamma, (10.0,)),
688-
(dist.Chi2, osp.chi2, (1.0,)),
689-
(dist.Chi2, osp.chi2, (0.1,)),
690-
(dist.Chi2, osp.chi2, (10.0,)),
691-
# TODO: add more test cases for Beta/StudentT (and Dirichlet too) when
692-
# their pathwise grad (independent of standard_gamma grad) is implemented.
693-
pytest.param(
694-
dist.Beta,
695-
osp.beta,
696-
(1.0, 1.0),
697-
marks=pytest.mark.xfail(
698-
reason="currently, variance of grad of beta sampler is large"
699-
),
700-
),
701-
pytest.param(
702-
dist.StudentT,
703-
osp.t,
704-
(1.0,),
705-
marks=pytest.mark.xfail(
706-
reason="currently, variance of grad of t sampler is large"
707-
),
708-
),
685+
(dist.Gamma, (1.0,)),
686+
(dist.Gamma, (0.1,)),
687+
(dist.Gamma, (10.0,)),
688+
(dist.Chi2, (1.0,)),
689+
(dist.Chi2, (0.1,)),
690+
(dist.Chi2, (10.0,)),
691+
(dist.Beta, (1.0, 1.0)),
692+
(dist.StudentT, (5.0, 2.0, 4.0)),
709693
],
710694
)
711-
def test_pathwise_gradient(jax_dist, sp_dist, params):
695+
def test_pathwise_gradient(jax_dist, params):
712696
rng_key = random.PRNGKey(0)
713-
N = 100
714-
z = jax_dist(*params).sample(key=rng_key, sample_shape=(N,))
715-
actual_grad = jacfwd(lambda x: jax_dist(*x).sample(key=rng_key, sample_shape=(N,)))(
716-
params
717-
)
718-
eps = 1e-3
719-
for i in range(len(params)):
720-
args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
721-
args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
722-
cdf_dot = (sp_dist(*args_rhs).cdf(z) - sp_dist(*args_lhs).cdf(z)) / (2 * eps)
723-
expected_grad = -cdf_dot / sp_dist(*params).pdf(z)
724-
assert_allclose(actual_grad[i], expected_grad, rtol=0.005)
697+
N = 1000000
698+
699+
def f(params):
700+
z = jax_dist(*params).sample(key=rng_key, sample_shape=(N,))
701+
return (z + z ** 2).mean(0)
702+
703+
def g(params):
704+
d = jax_dist(*params)
705+
return d.mean + d.variance + d.mean ** 2
706+
707+
actual_grad = grad(f)(params)
708+
expected_grad = grad(g)(params)
709+
assert_allclose(actual_grad, expected_grad, rtol=0.005)
725710

726711

727712
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)