Skip to content

Commit c3a1fe5

Browse files
committed
added cannot_sample_rv
1 parent 7621508 commit c3a1fe5

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

pymc/distributions/multivariate.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -619,12 +619,6 @@ def dist(cls, n, p, *args, **kwargs):
619619
return super().dist([n, p], *args, **kwargs)
620620

621621
def support_point(rv, size, n, p):
622-
observed = getattr(rv.tag, "observed", None)
623-
if observed is None:
624-
raise ValueError(
625-
"Latent Multinomial variables are not supported for sampling. "
626-
"Use a Categorical variable instead."
627-
)
628622
n = pt.shape_padright(n)
629623
mean = n * p
630624
mode = pt.round(mean)

pymc/sampling/mcmc.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
5252
from pymc.backends.zarr import ZarrChain, ZarrTrace
5353
from pymc.blocking import DictToArrayBijection
54+
from pymc.distributions.multivariate import Multinomial
5455
from pymc.exceptions import SamplingError
5556
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
5657
from pymc.model import Model, modelcontext
@@ -63,6 +64,7 @@
6364
)
6465
from pymc.step_methods import NUTS, CompoundStep
6566
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
67+
from pymc.step_methods.cannot_sample import CannotSampleRV
6668
from pymc.step_methods.hmc import quadpotential
6769
from pymc.util import (
6870
ProgressBarManager,
@@ -144,6 +146,13 @@ def instantiate_steppers(
144146
if initial_point is None:
145147
initial_point = model.initial_point()
146148

149+
for rv in model.free_RVs:
150+
if isinstance(rv.owner.op, Multinomial) and getattr(rv.tag, "observed", None) is None:
151+
for step_class in list(selected_steps.keys()):
152+
if rv in selected_steps[step_class]:
153+
selected_steps[step_class].remove(rv)
154+
selected_steps.setdefault(CannotSampleRV, []).append(rv)
155+
147156
for step_class, vars in selected_steps.items():
148157
if vars:
149158
name = getattr(step_class, "name")

pymc/step_methods/cannot_sample.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from pymc.step_methods.arraystep import ArrayStep
2+
3+
class CannotSampleRV(ArrayStep):
4+
"""
5+
A step method that raises an error when sampling a latent Multinomial variable.
6+
"""
7+
name = "cannot_sample_rv"
8+
def __init__(self, vars, **kwargs):
9+
# Remove keys that ArrayStep.__init__ does not accept.
10+
kwargs.pop("model", None)
11+
kwargs.pop("initial_point", None)
12+
kwargs.pop("compile_kwargs", None)
13+
self.vars = vars
14+
super().__init__(vars=vars,fs=[], **kwargs)
15+
16+
def astep(self, q0):
17+
# This method is required by the abstract base class.
18+
raise ValueError(
19+
"Latent Multinomial variables are not supported"
20+
)
21+

0 commit comments

Comments
 (0)