|
47 | 47 | ParameterValueError, |
48 | 48 | local_check_parameter_to_ninf_switch, |
49 | 49 | ) |
| 50 | +from pymc.model import Model |
| 51 | +from pymc.model.fgraph import fgraph_from_model |
50 | 52 | from pymc.pytensorf import compile, floatX, inputvars, rvs_in_graph |
51 | 53 |
|
52 | 54 | # This mode can be used for tests where model compilations takes the bulk of the runtime |
@@ -239,7 +241,7 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None): |
239 | 241 | if extra_args is None: |
240 | 242 | extra_args = {} |
241 | 243 |
|
242 | | - with pm.Model() as m: |
| 244 | + with Model() as m: |
243 | 245 | param_vars = {} |
244 | 246 | for v, dom in vardomains.items(): |
245 | 247 | v_pt = pytensor.shared(np.asarray(dom.vals[0])) |
@@ -1209,3 +1211,37 @@ def equal_computations_up_to_root( |
1209 | 1211 | return False |
1210 | 1212 |
|
1211 | 1213 | return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs) # type: ignore[arg-type] |
| 1214 | + |
| 1215 | + |
| 1216 | +def assert_equivalent_models(model1: Model, model2: Model): |
| 1217 | + """Check whether two PyMC models are equivalent. |
| 1218 | +
|
| 1219 | + Examples |
| 1220 | + -------- |
| 1221 | +
|
| 1222 | + .. code-block:: python |
| 1223 | +
|
| 1224 | + import pymc as pm |
| 1225 | + from pymc_extras.utils.model_equivalence import equivalent_models |
| 1226 | +
|
| 1227 | + with pm.Model() as m1: |
| 1228 | + x = pm.Normal("x") |
| 1229 | + y = pm.Normal("y", x) |
| 1230 | +
|
| 1231 | + with pm.Model() as m2: |
| 1232 | + x = pm.Normal("x") |
| 1233 | + y = pm.Normal("y", x + 1) |
| 1234 | +
|
| 1235 | + with pm.Model() as m3: |
| 1236 | + x = pm.Normal("x") |
| 1237 | + y = pm.Normal("y", x) |
| 1238 | +
|
| 1239 | + assert not equivalent_models(m1, m2) |
| 1240 | + assert equivalent_models(m1, m3) |
| 1241 | +
|
| 1242 | + """ |
| 1243 | + fgraph1, _ = fgraph_from_model(model1) |
| 1244 | + fgraph2, _ = fgraph_from_model(model2) |
| 1245 | + |
| 1246 | + are_equivalent = equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs) |
| 1247 | + assert are_equivalent, "Models are not equivalent" |
0 commit comments