Skip to content

Commit 1c39540

Browse files
committed
raise when sampling a multinomial
1 parent 0772383 commit 1c39540

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pymc/distributions/multivariate.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,12 @@ def dist(cls, n, p, *args, **kwargs):
619619
return super().dist([n, p], *args, **kwargs)
620620

621621
def support_point(rv, size, n, p):
622+
observed = getattr(rv.tag, "observed", None)
623+
if observed is None:
624+
raise ValueError(
625+
"Latent Multinomial variables are not supported for sampling. "
626+
"Use a Categorical variable instead."
627+
)
622628
n = pt.shape_padright(n)
623629
mean = n * p
624630
mode = pt.round(mean)

tests/distributions/test_distribution.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,14 @@ def test_issue_4499(self):
8282
with pm.Model(check_bounds=False) as m:
8383
x = pm.DiracDelta("x", 1, size=10)
8484
npt.assert_almost_equal(m.compile_logp()({"x": np.ones(10)}), 0 * 10)
85-
85+
def test_issue_7548(self):
86+
#Test for bug in Multinomial, it should raise when trying to sample a Multinomial variable
87+
with pm.Model() as model:
88+
p = [0.3, 0.4, 0.3]
89+
n = 10
90+
x = pm.Multinomial("x", n=n, p=p)
91+
with pytest.raises(ValueError, match="Latent Multinomial variables are not supported"):
92+
pm.sample(draws=100, chains=1)
8693

8794
def test_all_distributions_have_support_points():
8895
import pymc.distributions as dist_module

0 commit comments

Comments
 (0)