Skip to content

Commit e77e238

Browse files
committed
Raise ValueError if random variables are present in the logp graph
Aeppl allows for graphs containing random variables. PyMC models do not generally allow for this, with the current exception of models that include SimulatorRVs.
1 parent b895e40 commit e77e238

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

pymc/distributions/logprob.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from aeppl.logprob import logprob as logp_aeppl
2525
from aeppl.transforms import TransformValuesOpt
2626
from aesara.graph.basic import graph_inputs, io_toposort
27+
from aesara.tensor.random.op import RandomVariable
2728
from aesara.tensor.subtensor import (
2829
AdvancedIncSubtensor,
2930
AdvancedIncSubtensor1,
@@ -223,6 +224,26 @@ def joint_logpt(
223224
tmp_rvs_to_values, extra_rewrites=transform_opt, use_jacobian=jacobian, **kwargs
224225
)
225226

227+
# Raise if there are unexpected RandomVariables in the logp graph
228+
# Only SimulatorRVs are allowed
229+
from pymc.distributions.simulator import SimulatorRV
230+
231+
unexpected_rv_nodes = [
232+
node
233+
for node in aesara.graph.ancestors(list(temp_logp_var_dict.values()))
234+
if (
235+
node.owner
236+
and isinstance(node.owner.op, RandomVariable)
237+
and not isinstance(node.owner.op, SimulatorRV)
238+
)
239+
]
240+
if unexpected_rv_nodes:
241+
raise ValueError(
242+
f"Random variables detected in the logp graph: {unexpected_rv_nodes}.\n"
243+
"This can happen when DensityDist logp or Interval transform functions "
244+
"reference nonlocal variables."
245+
)
246+
226247
# aeppl returns the logpt for every single value term we provided to it. This includes
227248
# the extra values we plugged in above, so we filter those we actually wanted in the
228249
# same order they were given in.

pymc/tests/test_logprob.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Subtensor,
2929
)
3030

31+
from pymc import DensityDist
3132
from pymc.aesaraf import floatX, walk_model
3233
from pymc.distributions.continuous import HalfFlat, Normal, TruncatedNormal, Uniform
3334
from pymc.distributions.discrete import Bernoulli
@@ -217,3 +218,12 @@ def test_model_unchanged_logprob_access():
217218
model.logpt()
218219
new_inputs = set(aesara.graph.graph_inputs([c]))
219220
assert original_inputs == new_inputs
221+
222+
223+
def test_unexpected_rvs():
224+
with Model() as model:
225+
x = Normal("x")
226+
y = DensityDist("y", logp=lambda *args: x)
227+
228+
with pytest.raises(ValueError, match="^Random variables detected in the logp graph"):
229+
model.logpt()

pymc/tests/test_parallel_sampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,11 @@ def test_spawn_densitydist_bound_method():
201201
N = 100
202202
with pm.Model() as model:
203203
mu = pm.Normal("mu", 0, 1)
204-
normal_dist = pm.Normal.dist(mu, 1, size=N)
205204

206-
def logp(x):
205+
def logp(x, mu):
206+
normal_dist = pm.Normal.dist(mu, 1, size=N)
207207
out = pm.logp(normal_dist, x)
208208
return out
209209

210-
obs = pm.DensityDist("density_dist", logp=logp, observed=np.random.randn(N), size=N)
210+
obs = pm.DensityDist("density_dist", mu, logp=logp, observed=np.random.randn(N), size=N)
211211
pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")

0 commit comments

Comments
 (0)