Skip to content

Commit 7e581f6

Browse files
Add assert_equivalent_models test helper
1 parent 8f4673b commit 7e581f6

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

pymc/data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ class MinibatchIndexRV(IntegersRV):
9292
class MinibatchOp(OpFromGraph):
9393
"""Encapsulate Minibatch random draws in an opaque OFG."""
9494

95+
# FIXME: __props__ should not be empty
96+
__props__ = ()
97+
9598
def __init__(self, *args, **kwargs):
9699
super().__init__(*args, **kwargs, inline=True)
97100

pymc/testing.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
ParameterValueError,
4848
local_check_parameter_to_ninf_switch,
4949
)
50+
from pymc.model import Model
51+
from pymc.model.fgraph import fgraph_from_model
5052
from pymc.pytensorf import compile, floatX, inputvars, rvs_in_graph
5153

5254
# This mode can be used for tests where model compilations takes the bulk of the runtime
@@ -239,7 +241,7 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None):
239241
if extra_args is None:
240242
extra_args = {}
241243

242-
with pm.Model() as m:
244+
with Model() as m:
243245
param_vars = {}
244246
for v, dom in vardomains.items():
245247
v_pt = pytensor.shared(np.asarray(dom.vals[0]))
@@ -1209,3 +1211,37 @@ def equal_computations_up_to_root(
12091211
return False
12101212

12111213
return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs) # type: ignore[arg-type]
1214+
1215+
1216+
def assert_equivalent_models(model1: Model, model2: Model):
1217+
"""Check whether two PyMC models are equivalent.
1218+
1219+
Examples
1220+
--------
1221+
1222+
.. code-block:: python
1223+
1224+
import pymc as pm
1225+
from pymc_extras.utils.model_equivalence import equivalent_models
1226+
1227+
with pm.Model() as m1:
1228+
x = pm.Normal("x")
1229+
y = pm.Normal("y", x)
1230+
1231+
with pm.Model() as m2:
1232+
x = pm.Normal("x")
1233+
y = pm.Normal("y", x + 1)
1234+
1235+
with pm.Model() as m3:
1236+
x = pm.Normal("x")
1237+
y = pm.Normal("y", x)
1238+
1239+
assert not equivalent_models(m1, m2)
1240+
assert equivalent_models(m1, m3)
1241+
1242+
"""
1243+
fgraph1, _ = fgraph_from_model(model1)
1244+
fgraph2, _ = fgraph_from_model(model2)
1245+
1246+
are_equivalent = equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs)
1247+
assert are_equivalent, "Models are not equivalent"

0 commit comments

Comments
 (0)