diff --git a/pymc/model/core.py b/pymc/model/core.py
index dc09288bc..66e633e15 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -52,7 +52,6 @@
 from pymc.logprob.basic import transformed_conditional_logp
 from pymc.logprob.transforms import Transform
 from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
-from pymc.model_graph import model_to_graphviz, model_to_mermaid
 from pymc.pytensorf import (
     PointFunc,
     SeedSequenceSeed,
@@ -440,6 +439,8 @@ def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
     def _display_(self):
         import marimo as mo

+        from pymc.model_graph import model_to_mermaid
+
         return mo.mermaid(model_to_mermaid(self))

     @staticmethod
@@ -2002,6 +2003,8 @@ def to_graphviz(
             # creates the file `schools.pdf`
             schools.to_graphviz().render("schools")
         """
+        from pymc.model_graph import model_to_graphviz
+
         return model_to_graphviz(
             model=self,
             var_names=var_names,
diff --git a/pymc/model_graph.py b/pymc/model_graph.py
index e62e5e244..50fd5227d 100644
--- a/pymc/model_graph.py
+++ b/pymc/model_graph.py
@@ -21,16 +21,11 @@
 from typing import Any, cast

 from pytensor import function
-from pytensor.graph import Apply
 from pytensor.graph.basic import ancestors, walk
-from pytensor.scalar.basic import Cast
-from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.variable import TensorVariable

-import pymc as pm
-
+from pymc.model.core import modelcontext
 from pymc.util import VarName, get_default_varnames, get_var_name

 __all__ = (
@@ -241,42 +236,32 @@ class ModelGraph:
     def __init__(self, model):
         self.model = model
         self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False)
+        self._all_vars = {model[var_name] for var_name in self._all_var_names}
         self.var_list = self.model.named_vars.values()

     def get_parent_names(self, var: TensorVariable) -> set[VarName]:
-        if var.owner is None or var.owner.inputs is None:
+        if var.owner is None:
             return set()

-        def _filter_non_parameter_inputs(var):
-            node = var.owner
-            if isinstance(node.op, Shape):
-                # Don't show shape-related dependencies
-                return []
-            if isinstance(node.op, RandomVariable):
-                # Filter out rng and size parameters or RandomVariable nodes
-                return node.op.dist_params(node)
-            else:
-                # Otherwise return all inputs
-                return node.inputs
-
-        blockers = set(self.model.named_vars)
+        named_vars = self._all_vars

         def _expand(x):
-            nonlocal blockers
-            if x.name in blockers:
+            if x in named_vars:
+                # Don't go beyond named_vars
                 return [x]
-            if isinstance(x.owner, Apply):
-                return reversed(_filter_non_parameter_inputs(x))
-            return []
-
-        parents = set()
-        for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand):
-            # Only consider nodes that are in the named model variables.
-            vname = getattr(x, "name", None)
-            if isinstance(vname, str) and vname in self._all_var_names:
-                parents.add(VarName(vname))
+            if x.owner is None:
+                return []
+            if isinstance(x.owner.op, Shape):
+                # Don't propagate shape-related dependencies
+                return []
+            # Continue walking the graph through the inputs
+            return x.owner.inputs

-        return parents
+        return {
+            cast(VarName, ancestor.name)  # type: ignore[union-attr]
+            for ancestor in walk(nodes=var.owner.inputs, expand=_expand)
+            if ancestor in named_vars
+        }

     def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]:
         if var_names is None:
@@ -312,35 +297,28 @@ def make_compute_graph(
         self, var_names: Iterable[VarName] | None = None
     ) -> dict[VarName, set[VarName]]:
         """Get map of var_name -> set(input var names) for the model."""
+        model = self.model
+        named_vars = self._all_vars
         input_map: dict[VarName, set[VarName]] = defaultdict(set)
-        for var_name in self.vars_to_plot(var_names):
-            var = self.model[var_name]
-            parent_name = self.get_parent_names(var)
-            input_map[var_name] = input_map[var_name].union(parent_name)
-
-            if var in self.model.observed_RVs:
-                obs_node = self.model.rvs_to_values[var]
-
-                # loop created so that the elif block can go through this again
-                # and remove any intermediate ops, notably dtype casting, to observations
-                while True:
-                    obs_name = obs_node.name
-                    if obs_name and obs_name != var_name:
-                        input_map[var_name] = input_map[var_name].difference({obs_name})
-                        input_map[obs_name] = input_map[obs_name].union({var_name})
-                        break
-                    elif (
-                        # for cases where observations are cast to a certain dtype
-                        # see issue 5795: https://github.com/pymc-devs/pymc/issues/5795
-                        obs_node.owner
-                        and isinstance(obs_node.owner.op, Elemwise)
-                        and isinstance(obs_node.owner.op.scalar_op, Cast)
-                    ):
-                        # we can retrieve the observation node by going up the graph
-                        obs_node = obs_node.owner.inputs[0]
-                    else:
-                        break
+        var_names_to_plot = self.vars_to_plot(var_names)
+        for var_name in var_names_to_plot:
+            parent_names = self.get_parent_names(model[var_name])
+            input_map[var_name].update(parent_names)
+
+        for var_name in var_names_to_plot:
+            if (var := model[var_name]) in model.observed_RVs:
+                # Make observed `Data` variables flow from the observed RV, not the other way around
+                # (in the generative graph they usually inform the shape of the observed RV).
+                # We have to iterate over the ancestors of the observed value because there can be
+                # deterministic operations between the `Data` variable and the observed value.
+                obs_var = model.rvs_to_values[var]
+                for ancestor in ancestors([obs_var]):
+                    if ancestor not in named_vars:
+                        continue
+                    obs_name = cast(VarName, ancestor.name)
+                    input_map[var_name].discard(obs_name)
+                    input_map[obs_name].add(var_name)

         return input_map

@@ -361,7 +339,7 @@ def get_plates(
         plates = defaultdict(set)

         # TODO: Evaluate all RV shapes at once
-        # This should help find discrepencies, and
+        # This should help find discrepancies, and
         # avoids unnecessary function compiles for determining labels.
         dim_lengths: dict[str, int] = {
             dim_name: fast_eval(value).item() for dim_name, value in self.model.dim_lengths.items()
@@ -662,7 +640,7 @@ def model_to_networkx(
             stacklevel=2,
         )

-    model = pm.modelcontext(model)
+    model = modelcontext(model)
     graph = ModelGraph(model)
     return make_networkx(
         name=model.name,
@@ -777,7 +755,7 @@ def model_to_graphviz(
             stacklevel=2,
         )

-    model = pm.modelcontext(model)
+    model = modelcontext(model)
     graph = ModelGraph(model)
     return make_graph(
         model.name,
@@ -910,7 +888,7 @@ def model_to_mermaid(model=None, *, var_names=None, include_dim_lengths: bool =
     """

-    model = pm.modelcontext(model)
+    model = modelcontext(model)
     graph = ModelGraph(model)
     plates = sorted(graph.get_plates(var_names=var_names), key=lambda plate: hash(plate.dim_info))
     edges = sorted(graph.edges(var_names=var_names))
diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py
index 409e255d7..032fbc938 100755
--- a/scripts/run_mypy.py
+++ b/scripts/run_mypy.py
@@ -168,7 +168,7 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]):
         for section, sdf in df.reset_index().groupby(args.groupby):
             print(f"\n\n[{section}]")
             for row in sdf.itertuples():
-                print(f"{row.file}:{row.line}: {row.type}: {row.message}")
+                print(f"{row.file}:{row.line}: {row.type} [{row.errorcode}]: {row.message}")
             print()
     else:
         print(
diff --git a/tests/model/test_core.py b/tests/model/test_core.py
index a12046c31..814cb114d 100644
--- a/tests/model/test_core.py
+++ b/tests/model/test_core.py
@@ -1750,7 +1750,7 @@ def school_model(J: int) -> pm.Model:
     )
     def test_graphviz_call_function(self, var_names, filenames) -> None:
         model = self.school_model(J=8)
-        with patch("pymc.model.core.model_to_graphviz") as mock_model_to_graphviz:
+        with patch("pymc.model_graph.model_to_graphviz") as mock_model_to_graphviz:
             model.to_graphviz(var_names=var_names, save=filenames)
             mock_model_to_graphviz.assert_called_once_with(
                 model=model,
diff --git a/tests/test_model_graph.py b/tests/test_model_graph.py
index 1ea1b694b..94b797f8e 100644
--- a/tests/test_model_graph.py
+++ b/tests/test_model_graph.py
@@ -516,14 +516,39 @@ def test_model_graph_with_intermediate_named_variables():
     with pm.Model() as m1:
         a = pm.Normal("a", 0, 1, shape=3)
         pm.Normal("b", a.mean(axis=-1), 1)
-    assert dict(ModelGraph(m1).make_compute_graph()) == {"a": set(), "b": {"a"}}
+    assert ModelGraph(m1).make_compute_graph() == {"a": set(), "b": {"a"}}

     with pm.Model() as m2:
         a = pm.Normal("a", 0, 1)
         b = a + 1
         b.name = "b"
         pm.Normal("c", b, 1)
-    assert dict(ModelGraph(m2).make_compute_graph()) == {"a": set(), "c": {"a"}}
+    assert ModelGraph(m2).make_compute_graph() == {"a": set(), "c": {"a"}}
+
+    # Regression test for https://github.com/pymc-devs/pymc/issues/7397
+    with pm.Model() as m3:
+        data = pt.as_tensor_variable(
+            np.ones((5, 3)),
+            name="C",
+        )
+        # C is a model variable with the same name as the `data` constant above,
+        # which used to be wrongly picked up as a dependency
+        C = pm.Deterministic("C", data)
+        # D depends on a tensor named "C", but that tensor is not a variable in the model
+        D = pm.Deterministic("D", data)
+        # E actually depends on the model variable `C`
+        E = pm.Deterministic("E", C)
+    assert ModelGraph(m3).make_compute_graph() == {"C": set(), "D": set(), "E": {"C"}}
+
+
+def test_model_graph_complex_observed_dependency():
+    with pm.Model() as model:
+        x = pm.Data("x", [0])
+        y = pm.Data("y", [0])
+        observed = pt.exp(x) + pt.log(y)
+        pm.Normal("obs", mu=0, observed=observed)
+
+    assert ModelGraph(model).make_compute_graph() == {"obs": set(), "x": {"obs"}, "y": {"obs"}}


 @pytest.fixture
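
Reviewer note, not part of the patch: a minimal sketch replaying the two regression tests above through `ModelGraph.make_compute_graph`, so the intended graph semantics can be checked end to end. Everything here (the API calls and the expected dictionaries) is taken from the tests in this diff; only the local variable names are illustrative.

import numpy as np
import pytensor.tensor as pt

import pymc as pm
from pymc.model_graph import ModelGraph

# Issue 7397: a plain named constant must not shadow the model variable
# that shares its name when collecting parents.
with pm.Model() as m:
    data = pt.as_tensor_variable(np.ones((5, 3)), name="C")  # not a model variable
    C = pm.Deterministic("C", data)  # model variable with the colliding name
    D = pm.Deterministic("D", data)  # depends only on the constant
    E = pm.Deterministic("E", C)  # depends on the model variable

# Only the true dependency E -> C survives; D has no model-level parents.
assert ModelGraph(m).make_compute_graph() == {"C": set(), "D": set(), "E": {"C"}}

# Observed values reached through deterministic ops: edges still flow from
# the observed RV back to the `Data` variables, never the other way around.
with pm.Model() as model:
    x = pm.Data("x", [0])
    y = pm.Data("y", [0])
    pm.Normal("obs", mu=0, observed=pt.exp(x) + pt.log(y))

assert ModelGraph(model).make_compute_graph() == {"obs": set(), "x": {"obs"}, "y": {"obs"}}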