|
47 | 47 | ParameterValueError, |
48 | 48 | local_check_parameter_to_ninf_switch, |
49 | 49 | ) |
50 | | -from pymc.model import Model |
51 | | -from pymc.model.fgraph import fgraph_from_model |
52 | 50 | from pymc.pytensorf import compile, floatX, inputvars, rvs_in_graph |
53 | 51 |
|
54 | 52 | # This mode can be used for tests where model compilations takes the bulk of the runtime |
@@ -241,7 +239,7 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None): |
241 | 239 | if extra_args is None: |
242 | 240 | extra_args = {} |
243 | 241 |
|
244 | | - with Model() as m: |
| 242 | + with pm.Model() as m: |
245 | 243 | param_vars = {} |
246 | 244 | for v, dom in vardomains.items(): |
247 | 245 | v_pt = pytensor.shared(np.asarray(dom.vals[0])) |
@@ -1211,37 +1209,3 @@ def equal_computations_up_to_root( |
1211 | 1209 | return False |
1212 | 1210 |
|
1213 | 1211 | return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs) # type: ignore[arg-type] |
1214 | | - |
1215 | | - |
1216 | | -def assert_equivalent_models(model1: Model, model2: Model): |
1217 | | - """Check whether two PyMC models are equivalent. |
1218 | | -
|
1219 | | - Examples |
1220 | | - -------- |
1221 | | -
|
1222 | | - .. code-block:: python |
1223 | | -
|
1224 | | - import pymc as pm |
1225 | | - from pymc_extras.utils.model_equivalence import equivalent_models |
1226 | | -
|
1227 | | - with pm.Model() as m1: |
1228 | | - x = pm.Normal("x") |
1229 | | - y = pm.Normal("y", x) |
1230 | | -
|
1231 | | - with pm.Model() as m2: |
1232 | | - x = pm.Normal("x") |
1233 | | - y = pm.Normal("y", x + 1) |
1234 | | -
|
1235 | | - with pm.Model() as m3: |
1236 | | - x = pm.Normal("x") |
1237 | | - y = pm.Normal("y", x) |
1238 | | -
|
1239 | | - assert not equivalent_models(m1, m2) |
1240 | | - assert equivalent_models(m1, m3) |
1241 | | -
|
1242 | | - """ |
1243 | | - fgraph1, _ = fgraph_from_model(model1) |
1244 | | - fgraph2, _ = fgraph_from_model(model2) |
1245 | | - |
1246 | | - are_equivalent = equal_computations_up_to_root(fgraph1.outputs, fgraph2.outputs) |
1247 | | - assert are_equivalent, "Models are not equivalent" |
0 commit comments