Skip to content

Commit c495d8a

Browse files
ricardoV94twiecki
authored andcommitted
Add test for freeze_dims_and_data in JAX backend
1 parent 40fb76c commit c495d8a

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tests/sampling/test_jax.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from pymc import ImputationWarning
3737
from pymc.distributions.multivariate import DirichletMultinomial, PosDefMatrix
38+
from pymc.model.transform.optimization import freeze_dims_and_data
3839
from pymc.sampling.jax import (
3940
_get_batched_jittered_initial_points,
4041
_get_log_likelihood,
@@ -514,6 +515,24 @@ def test_convergence_warnings(caplog, nuts_sampler):
514515

515516

516517
def test_dirichlet_multinomial():
518+
"""Test we can draw from a DM in the JAX backend if the shape is constant."""
517519
dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01)
518520
dm_draws = pm.draw(dm, mode="JAX")
519521
np.testing.assert_equal(dm_draws, np.eye(3) * 5)
522+
523+
524+
def test_dirichlet_multinomial_dims():
525+
"""Test we can draw from a DM with a shape defined by dims in the JAX backend,
526+
after freezing those dims.
527+
"""
528+
with pm.Model(coords={"trial": range(3), "item": range(3)}) as m:
529+
dm = DirichletMultinomial("dm", n=5, a=np.eye(3) * 1e6 + 0.01, dims=("trial", "item"))
530+
531+
# JAX does not allow us to JIT a function with dynamic shape
532+
with pytest.raises(TypeError):
533+
pm.draw(dm, mode="JAX")
534+
535+
# Should be fine after freezing the dims that specify the shape
536+
frozen_dm = freeze_dims_and_data(m)["dm"]
537+
dm_draws = pm.draw(frozen_dm, mode="JAX")
538+
np.testing.assert_equal(dm_draws, np.eye(3) * 5)

0 commit comments

Comments
 (0)