@@ -12,7 +12,7 @@
 import scipy.stats as osp
 
 import jax
-from jax import grad, jacfwd, lax, vmap
+from jax import grad, lax, vmap
 import jax.numpy as jnp
 import jax.random as random
 from jax.scipy.special import expit, logsumexp
@@ -680,48 +680,33 @@ def fn(args):
 
 
 @pytest.mark.parametrize(
-    "jax_dist, sp_dist, params",
+    "jax_dist, params",
     [
-        (dist.Gamma, osp.gamma, (1.0,)),
-        (dist.Gamma, osp.gamma, (0.1,)),
-        (dist.Gamma, osp.gamma, (10.0,)),
-        (dist.Chi2, osp.chi2, (1.0,)),
-        (dist.Chi2, osp.chi2, (0.1,)),
-        (dist.Chi2, osp.chi2, (10.0,)),
-        # TODO: add more test cases for Beta/StudentT (and Dirichlet too) when
-        # their pathwise grad (independent of standard_gamma grad) is implemented.
-        pytest.param(
-            dist.Beta,
-            osp.beta,
-            (1.0, 1.0),
-            marks=pytest.mark.xfail(
-                reason="currently, variance of grad of beta sampler is large"
-            ),
-        ),
-        pytest.param(
-            dist.StudentT,
-            osp.t,
-            (1.0,),
-            marks=pytest.mark.xfail(
-                reason="currently, variance of grad of t sampler is large"
-            ),
-        ),
+        (dist.Gamma, (1.0,)),
+        (dist.Gamma, (0.1,)),
+        (dist.Gamma, (10.0,)),
+        (dist.Chi2, (1.0,)),
+        (dist.Chi2, (0.1,)),
+        (dist.Chi2, (10.0,)),
+        (dist.Beta, (1.0, 1.0)),
+        (dist.StudentT, (5.0, 2.0, 4.0)),
     ],
 )
-def test_pathwise_gradient(jax_dist, sp_dist, params):
+def test_pathwise_gradient(jax_dist, params):
     rng_key = random.PRNGKey(0)
-    N = 100
-    z = jax_dist(*params).sample(key=rng_key, sample_shape=(N,))
-    actual_grad = jacfwd(lambda x: jax_dist(*x).sample(key=rng_key, sample_shape=(N,)))(
-        params
-    )
-    eps = 1e-3
-    for i in range(len(params)):
-        args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
-        args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
-        cdf_dot = (sp_dist(*args_rhs).cdf(z) - sp_dist(*args_lhs).cdf(z)) / (2 * eps)
-        expected_grad = -cdf_dot / sp_dist(*params).pdf(z)
-        assert_allclose(actual_grad[i], expected_grad, rtol=0.005)
+    N = 1000000
+
+    def f(params):
+        z = jax_dist(*params).sample(key=rng_key, sample_shape=(N,))
+        return (z + z ** 2).mean(0)
+
+    def g(params):
+        d = jax_dist(*params)
+        return d.mean + d.variance + d.mean ** 2
+
+    actual_grad = grad(f)(params)
+    expected_grad = grad(g)(params)
+    assert_allclose(actual_grad, expected_grad, rtol=0.005)
 
 
 @pytest.mark.parametrize(
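Why the new test works: the removed check differentiated each sample directly with jacfwd and compared against the implicit-reparameterization identity dz/dtheta = -(dF(z; theta)/dtheta) / p(z; theta), which is noisy per sample (hence the xfail marks citing large gradient variance for the Beta and StudentT samplers). The replacement compares expectations instead: since E[z + z**2] = mean + variance + mean**2, the gradient of the Monte Carlo average over N samples must match the gradient of the closed-form moments, and averaging over a million samples drives the variance down enough for a tight rtol. A minimal standalone sketch of that check for Gamma, assuming numpyro's dist.Gamma API (the test file's dist alias); the helper names monte_carlo and analytic are illustrative, not from the test file:

import jax.numpy as jnp
import jax.random as random
from jax import grad

import numpyro.distributions as dist

rng_key = random.PRNGKey(0)
N = 1000000

def monte_carlo(concentration):
    # Sample through jax's differentiable gamma sampler and average z + z**2.
    z = dist.Gamma(concentration).sample(key=rng_key, sample_shape=(N,))
    return jnp.mean(z + z ** 2)

def analytic(concentration):
    # Closed form for E[z + z**2] = mean + variance + mean**2.
    d = dist.Gamma(concentration)
    return d.mean + d.variance + d.mean ** 2

actual_grad = grad(monte_carlo)(1.0)
expected_grad = grad(analytic)(1.0)
# The two gradients agree up to Monte Carlo error.
print(actual_grad, expected_grad)

For a rate-1 Gamma, mean = variance = concentration, so the analytic function is 2a + a**2 and its gradient at concentration 1.0 is 4; the Monte Carlo estimate should match within roughly the test's rtol=0.005 at N = 1000000.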