Skip to content

Commit fafc020

Browse files
committed
Use better assert_no_rvs from logprob submodule
This utility can find RVs in inner graphs
1 parent bae121a commit fafc020

File tree

2 files changed

+19
-31
lines changed

2 files changed

+19
-31
lines changed

pymc/testing.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from numpy import random as nr
2525
from numpy import testing as npt
2626
from pytensor.compile.mode import Mode
27-
from pytensor.graph.basic import ancestors
27+
from pytensor.graph.basic import walk
28+
from pytensor.graph.op import HasInnerGraph
2829
from pytensor.graph.rewriting.basic import in2out
2930
from pytensor.tensor import TensorVariable
3031
from pytensor.tensor.random.op import RandomVariable
@@ -37,7 +38,7 @@
3738
from pymc.distributions.shape_utils import change_dist_size
3839
from pymc.initial_point import make_initial_point_fn
3940
from pymc.logprob import joint_logp
40-
from pymc.logprob.abstract import icdf
41+
from pymc.logprob.abstract import MeasurableVariable, icdf
4142
from pymc.logprob.utils import ParameterValueError
4243
from pymc.pytensorf import (
4344
compile_pymc,
@@ -958,5 +959,18 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
958959

959960

960961
def assert_no_rvs(var):
961-
assert not any(isinstance(v.owner.op, RandomVariable) for v in ancestors([var]) if v.owner)
962-
return var
962+
"""Assert that there are no `MeasurableVariable` nodes in a graph."""
963+
964+
def expand(r):
965+
owner = r.owner
966+
if owner:
967+
inputs = list(reversed(owner.inputs))
968+
969+
if isinstance(owner.op, HasInnerGraph):
970+
inputs += owner.op.inner_outputs
971+
972+
return inputs
973+
974+
for v in walk([var], expand, False):
975+
if v.owner and isinstance(v.owner.op, (RandomVariable, MeasurableVariable)):
976+
raise AssertionError(f"RV found in graph: {v}")

tests/logprob/utils.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,11 @@
3939
import numpy as np
4040

4141
from pytensor import tensor as pt
42-
from pytensor.graph.basic import walk
43-
from pytensor.graph.op import HasInnerGraph
4442
from pytensor.tensor.var import TensorVariable
4543
from scipy import stats as stats
4644

4745
from pymc.logprob import factorized_joint_logprob
48-
from pymc.logprob.abstract import (
49-
MeasurableVariable,
50-
get_measurable_outputs,
51-
icdf,
52-
logcdf,
53-
logprob,
54-
)
46+
from pymc.logprob.abstract import get_measurable_outputs, icdf, logcdf, logprob
5547
from pymc.logprob.utils import ignore_logprob
5648

5749

@@ -82,24 +74,6 @@ def joint_logprob(*args, sum: bool = True, **kwargs) -> Optional[TensorVariable]
8274
return pt.add(*logprob.values())
8375

8476

85-
def assert_no_rvs(var):
86-
"""Assert that there are no `MeasurableVariable` nodes in a graph."""
87-
88-
def expand(r):
89-
owner = r.owner
90-
if owner:
91-
inputs = list(reversed(owner.inputs))
92-
93-
if isinstance(owner.op, HasInnerGraph):
94-
inputs += owner.op.inner_outputs
95-
96-
return inputs
97-
98-
for v in walk([var], expand, False):
99-
if v.owner and isinstance(v.owner.op, MeasurableVariable):
100-
raise AssertionError(f"Variable {v} is a MeasurableVariable")
101-
102-
10377
def simulate_poiszero_hmm(
10478
N, mu=10.0, pi_0_a=np.r_[1, 1], p_0_a=np.r_[5, 1], p_1_a=np.r_[1, 1], seed=None
10579
):

0 commit comments

Comments
 (0)