|
35 | 35 |
|
36 | 36 | from pymc import ImputationWarning
|
37 | 37 | from pymc.distributions.multivariate import DirichletMultinomial, PosDefMatrix
|
| 38 | +from pymc.model.transform.optimization import freeze_dims_and_data |
38 | 39 | from pymc.sampling.jax import (
|
39 | 40 | _get_batched_jittered_initial_points,
|
40 | 41 | _get_log_likelihood,
|
@@ -514,6 +515,24 @@ def test_convergence_warnings(caplog, nuts_sampler):
|
514 | 515 |
|
515 | 516 |
|
516 | 517 | def test_dirichlet_multinomial():
|
| 518 | + """Test we can draw from a DM in the JAX backend if the shape is constant.""" |
517 | 519 | dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01)
|
518 | 520 | dm_draws = pm.draw(dm, mode="JAX")
|
519 | 521 | np.testing.assert_equal(dm_draws, np.eye(3) * 5)
|
| 522 | + |
| 523 | + |
| 524 | +def test_dirichlet_multinomial_dims(): |
| 525 | + """Test we can draw from a DM with a shape defined by dims in the JAX backend, |
| 526 | + after freezing those dims. |
| 527 | + """ |
| 528 | + with pm.Model(coords={"trial": range(3), "item": range(3)}) as m: |
| 529 | + dm = DirichletMultinomial("dm", n=5, a=np.eye(3) * 1e6 + 0.01, dims=("trial", "item")) |
| 530 | + |
| 531 | + # JAX does not allow us to JIT a function with dynamic shape |
| 532 | + with pytest.raises(TypeError): |
| 533 | + pm.draw(dm, mode="JAX") |
| 534 | + |
| 535 | + # Should be fine after freezing the dims that specify the shape |
| 536 | + frozen_dm = freeze_dims_and_data(m)["dm"] |
| 537 | + dm_draws = pm.draw(frozen_dm, mode="JAX") |
| 538 | + np.testing.assert_equal(dm_draws, np.eye(3) * 5) |
0 commit comments